Coverage for src / updates2mqtt / helpers.py: 88%

172 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-03 23:58 +0000

1import datetime as dt 

2import re 

3import time 

4from threading import Event 

5from typing import Any 

6from urllib.parse import urlparse 

7 

8import structlog 

9from hishel import CacheOptions, SpecificationPolicy # pyright: ignore[reportAttributeAccessIssue] 

10from hishel.httpx import SyncCacheClient 

11from httpx import Response 

12from tzlocal import get_localzone 

13 

14from updates2mqtt.config import Selector 

15 

# Module-level structlog logger; Throttler falls back to it when no logger is passed in.
log = structlog.get_logger()

17 

18 

19def timestamp(time_value: float | None) -> str | None: 

20 if time_value is None: 

21 return None 

22 try: 

23 return dt.datetime.fromtimestamp(time_value, tz=get_localzone()).isoformat() 

24 except: # noqa: E722 

25 return None 

26 

27 

class Selection:
    """Outcome of testing a value against a Selector's ``include``/``exclude`` regex lists.

    The instance is truthy when the value is selected. ``matched`` is set to the value
    whenever an exclude or include pattern matched it.

    Rules, as implemented:
      * ``None`` value: selected only when the selector has no include list.
      * exclude list present: a match against any pattern de-selects the value.
      * include list present: the value is selected iff some pattern matches; this
        decision overrides whatever the exclude step concluded.
    """

    def __init__(self, selector: Selector, value: str | None) -> None:
        self.result: bool = True
        self.matched: str | None = None
        if value is None:
            self.result = selector.include is None
            return

        if selector.exclude is not None:
            hit = any(re.search(pattern, value) for pattern in selector.exclude)
            self.result = not hit
            if hit:
                self.matched = value

        if selector.include is not None:
            hit = any(re.search(pattern, value) for pattern in selector.include)
            self.result = hit
            if hit:
                self.matched = value

    def __bool__(self) -> bool:
        """Report the selection result so a Selection can be used directly in conditions."""
        return self.result

49 

50 

class ThrottledError(Exception):
    """Signals that requests to a site are being throttled.

    ``retry_secs`` holds the number of seconds the caller should wait before retrying.
    """

    def __init__(self, message: str, retry_secs: int) -> None:
        self.retry_secs = retry_secs
        super().__init__(message)

55 

56 

57class Throttler: 

58 DEFAULT_SITE = "DEFAULT_SITE" 

59 

60 def __init__(self, api_throttle_pause: int = 30, logger: Any | None = None, semaphore: Event | None = None) -> None: 

61 self.log: Any = logger or log 

62 self.pause_api_until: dict[str, float] = {} 

63 self.api_throttle_pause: int = api_throttle_pause 

64 self.semaphore = semaphore 

65 

66 def check_throttle(self, index_name: str | None = None) -> bool: 

67 if self.semaphore and self.semaphore.is_set(): 67 ↛ 68line 67 didn't jump to line 68 because the condition on line 67 was never true

68 return True 

69 index_name = index_name or self.DEFAULT_SITE 

70 if self.pause_api_until.get(index_name) is not None: 

71 if self.pause_api_until[index_name] < time.time(): 

72 del self.pause_api_until[index_name] 

73 self.log.info("%s throttling wait complete", index_name) 

74 else: 

75 self.log.debug("%s throttling has %0.3f secs left", index_name, self.pause_api_until[index_name] - time.time()) 

76 return True 

77 return False 

78 

79 def throttle( 

80 self, 

81 index_name: str | None = None, 

82 retry_secs: int | None = None, 

83 explanation: str | None = None, 

84 raise_exception: bool = False, 

85 ) -> None: 

86 index_name = index_name or self.DEFAULT_SITE 

87 retry_secs = retry_secs if retry_secs and retry_secs > 0 else self.api_throttle_pause 

88 self.log.warn("%s throttling requests for %s seconds, %s", index_name, retry_secs, explanation) 

89 self.pause_api_until[index_name] = time.time() + retry_secs 

90 if raise_exception: 90 ↛ 91line 90 didn't jump to line 91 because the condition on line 90 was never true

91 raise ThrottledError(explanation or f"{index_name} throttled request", retry_secs) 

92 

93 

class CacheMetadata:
    """Snapshot of the hishel caching extensions attached to an httpx response."""

    def __init__(self, response: Response) -> None:
        ext = response.extensions
        self.from_cache: bool = ext.get("hishel_from_cache", False)
        self.revalidated: bool = ext.get("hishel_revalidated", False)
        self.created_at: float | None = ext.get("hishel_created_at")
        self.stored: bool = ext.get("hishel_stored", False)
        # Seconds since the cache entry was created, only when hishel reported a creation time
        self.age: float | None = None if self.created_at is None else time.time() - self.created_at

    def __str__(self) -> str:
        """One-line summary of the cache flags and entry age."""
        summary = f"cached: {self.from_cache}, revalidated: {self.revalidated}, age:{self.age}, stored:{self.stored}"
        return summary

109 

110 

class APIStats:
    """Running request statistics for a single API host."""

    def __init__(self) -> None:
        self.fetches: int = 0  # requests recorded, including failures
        self.cached: int = 0  # responses hishel flagged as served from cache
        self.revalidated: int = 0  # responses hishel flagged as revalidated
        self.failed: dict[int, int] = {}  # HTTP status -> count; key 0 means no response at all
        self.elapsed: float = 0  # summed response elapsed time, in seconds
        self.max_cache_age: float = 0  # oldest cache entry age seen, in seconds

    def tick(self, response: Response | None) -> None:
        """Record one request outcome; ``None`` means the request produced no response."""
        self.fetches += 1
        if response is None:
            self.failed[0] = self.failed.get(0, 0) + 1
            return
        cache_metadata: CacheMetadata = CacheMetadata(response)
        if cache_metadata.from_cache:
            self.cached += 1
        if cache_metadata.revalidated:
            self.revalidated += 1
        if response.elapsed:
            # total_seconds() includes whole days, which a seconds + microseconds sum drops
            self.elapsed += response.elapsed.total_seconds()
        if not response.is_success:
            self.failed[response.status_code] = self.failed.get(response.status_code, 0) + 1
        if cache_metadata.age is not None and cache_metadata.age > self.max_cache_age:
            self.max_cache_age = cache_metadata.age

    def hit_ratio(self) -> float:
        """Fraction of fetches served from cache, rounded to 2 places (0 when nothing recorded)."""
        return round(self.cached / self.fetches, 2) if self.cached and self.fetches else 0

    def average_elapsed(self) -> float:
        """Mean elapsed seconds per fetch, rounded to 2 places (0 when nothing recorded)."""
        return round(self.elapsed / self.fetches, 2) if self.elapsed and self.fetches else 0

    def __str__(self) -> str:
        """Log line friendly string summary"""
        return (
            f"fetches: {self.fetches}, cache ratio: {self.hit_ratio():.2%}, revalidated: {self.revalidated}, "
            + f"errors: {', '.join(f'{status_code}:{fails}' for status_code, fails in self.failed.items()) or '0'}, "
            + f"oldest cache hit: {self.max_cache_age:.2f}s, avg elapsed: {self.average_elapsed()}s"
        )

151 

152 

class APIStatsCounter:
    """Aggregate APIStats per host and periodically log a summary of all hosts."""

    def __init__(self, stats_report_interval: int = 100) -> None:
        """Create a counter.

        Args:
            stats_report_interval: Log a summary every this many recorded fetches;
                zero or a negative value disables the periodic summary.

        """
        self.stats_report_interval: int = stats_report_interval
        self.host_stats: dict[str, APIStats] = {}
        self.fetches: int = 0
        self.log: Any = structlog.get_logger().bind()

    def stats(self, url: str, response: Response | None) -> None:
        """Record the outcome of fetching ``url`` against its host's stats.

        Best-effort: any error while recording is logged as a warning and swallowed,
        so recording statistics never raises to the caller.
        """
        try:
            host: str = urlparse(url).hostname or "UNKNOWN"
            self.host_stats.setdefault(host, APIStats()).tick(response)
            self.fetches += 1
            interval = self.stats_report_interval
            # Guard the modulo: a non-positive interval disables the summary instead of raising
            if interval > 0 and self.fetches % interval == 0:
                self.log.info(
                    "OCI_V2 API Stats Summary\n%s",
                    "\n".join(f"{stats_host} {host_stats}" for stats_host, host_stats in self.host_stats.items()),
                )
        except Exception as e:
            self.log.warning("Failed to tick stats: %s", e)

172 

173 

def _build_request_headers(
    cache_ttl: int | None,
    bearer_token: str | None,
    response_type: str | list[str] | None,
) -> list[tuple[str, str]]:
    """Assemble the default request headers used by ``fetch_url``."""
    headers: list[tuple[str, str]] = []
    if cache_ttl is not None:
        # Only request a max age when one was given; None defers to the server's caching headers
        # (previously this sent the invalid directive "max-age=None").
        headers.append(("cache-control", f"max-age={cache_ttl}"))
    if bearer_token:
        headers.append(("Authorization", f"Bearer {bearer_token}"))
    if response_type:
        mime_types = [response_type] if isinstance(response_type, str) else response_type
        if isinstance(mime_types, (tuple, list)):
            # One Accept header per acceptable MIME type
            headers.extend(("Accept", mime_type) for mime_type in mime_types)
    return headers


def fetch_url(
    url: str,
    cache_ttl: int | None = None,  # default to server responses for cache ttl
    bearer_token: str | None = None,
    response_type: str | list[str] | None = None,
    follow_redirects: bool = False,
    allow_stale: bool = False,
    method: str = "GET",
    api_stats_counter: APIStatsCounter | None = None,
) -> Response | None:
    """Fetch ``url`` through a hishel caching httpx client.

    Args:
        url: URL to request.
        cache_ttl: Max age in seconds sent as ``cache-control`` and passed as the
            ``hishel_ttl`` extension; None leaves cache lifetime to the server.
        bearer_token: Sent as an ``Authorization: Bearer`` header when given.
        response_type: MIME type(s) sent as ``Accept`` headers.
        follow_redirects: Whether the client follows redirects.
        allow_stale: Passed to hishel ``CacheOptions(allow_stale=...)``.
        method: HTTP method.
        api_stats_counter: Optional counter that records every attempt, including failures.

    Returns:
        The response (non-success statuses included), or None if the request raised.

    """
    try:
        headers = _build_request_headers(cache_ttl, bearer_token, response_type)
        cache_policy = SpecificationPolicy(
            cache_options=CacheOptions(
                shared=False,  # Private browser cache
                allow_stale=allow_stale,
            )
        )
        # Keep the bearer token out of the logs
        log_headers: list[tuple[str, str]] = [h for h in headers if h[0] != "Authorization"]
        with SyncCacheClient(headers=headers, follow_redirects=follow_redirects, policy=cache_policy) as client:
            log.debug(
                "Fetching URL %s, redirects=%s, headers=%s, cache_ttl=%s", url, follow_redirects, log_headers, cache_ttl
            )
            response: Response = client.request(method=method, url=url, extensions={"hishel_ttl": cache_ttl})
            cache_metadata: CacheMetadata = CacheMetadata(response)
            if not response.is_success:
                log.debug("URL %s fetch returned non-success status: %s, %s", url, response.status_code, cache_metadata.stored)
            else:
                log.debug(
                    "URL response: status: %s, cached: %s, revalidated: %s, cache age: %s, stored: %s",
                    response.status_code,
                    cache_metadata.from_cache,
                    cache_metadata.revalidated,
                    cache_metadata.age,
                    cache_metadata.stored,
                )
            if api_stats_counter:
                api_stats_counter.stats(url, response)
            return response
    except Exception as e:
        log.debug("URL %s failed to fetch: %s", url, e)
        if api_stats_counter:
            api_stats_counter.stats(url, None)
        return None

223 

224 

def validate_url(url: str, cache_ttl: int = 1500) -> bool:
    """Probe ``url`` with a HEAD request (redirects followed, response cached).

    Args:
        url: URL to probe.
        cache_ttl: Cache TTL in seconds forwarded to ``fetch_url``.

    Returns:
        False when the request produced no response or the server answered 404;
        True for any other status.

    """
    head: Response | None = fetch_url(url, method="HEAD", cache_ttl=cache_ttl, follow_redirects=True)
    if head is None:
        return False
    return head.status_code != 404

228 

229 

def sanitize_name(name: str, replacement: str = "_", max_len: int = 64) -> str:
    """Strict sanitization that removes/replaces common problematic characters for MQTT or HA

    - Replaces every run of characters outside ``A-Z a-z 0-9 _ - .`` (spaces, control
      characters, ``/``, ``+``, ``#``, non-ASCII, ...) with a single ``replacement``
    - Collapses consecutive repeats of ``replacement`` into one
    - Truncates the result to at most ``max_len`` UTF-8 bytes

    Args:
        name: The topic component string to sanitize
        replacement: Text to substitute for invalid characters (default: "_"); an empty
            string deletes them. Used literally, never as a regex template.
        max_len: Largest acceptable name size, in UTF-8 bytes

    Returns:
        Sanitized name safe for most MQTT brokers; any change is logged at info level

    Raises:
        ValueError: if ``name`` is empty or nothing is left after sanitization

    """
    if not name:
        raise ValueError("Name cannot be empty")
    orig_name: str = name

    # A callable repl keeps ``replacement`` literal (a str repl would process backslash escapes)
    def _substitute(_match: re.Match[str]) -> str:
        return replacement

    name = re.sub(r"[^A-Za-z0-9_\-\.]+", _substitute, name)

    # Replace multiple consecutive replacement strings with a single one. The non-capturing
    # group makes ``+`` repeat the whole replacement rather than only its last character.
    if replacement:
        name = re.sub(f"(?:{re.escape(replacement)})+", _substitute, name)

    # Trim to max length
    topic_bytes = name.encode("utf-8")
    if len(topic_bytes) > max_len:
        # errors="ignore" drops a multi-byte character split by the cut
        name = topic_bytes[:max_len].decode("utf-8", errors="ignore")

    if not name:
        raise ValueError("Topic became empty after sanitization")
    if name != orig_name:
        log.info("Component name %s changed to %s for MQTT/HA compatibility", orig_name, name)

    return name