Coverage for src / updates2mqtt / helpers.py: 88%
172 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-03 23:58 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-03 23:58 +0000
1import datetime as dt
2import re
3import time
4from threading import Event
5from typing import Any
6from urllib.parse import urlparse
8import structlog
9from hishel import CacheOptions, SpecificationPolicy # pyright: ignore[reportAttributeAccessIssue]
10from hishel.httpx import SyncCacheClient
11from httpx import Response
12from tzlocal import get_localzone
14from updates2mqtt.config import Selector
# Module-level structlog logger; used as the default logger by Throttler and directly by
# fetch_url and sanitize_name.
log = structlog.get_logger()
19def timestamp(time_value: float | None) -> str | None:
20 if time_value is None:
21 return None
22 try:
23 return dt.datetime.fromtimestamp(time_value, tz=get_localzone()).isoformat()
24 except: # noqa: E722
25 return None
class Selection:
    """Outcome of applying a Selector's include/exclude regex lists to one value.

    The instance is truthy when the value is selected. ``matched`` holds the value
    whenever any exclude or include pattern matched it, otherwise None.
    """

    def __init__(self, selector: Selector, value: str | None) -> None:
        self.result: bool = True
        self.matched: str | None = None

        if value is None:
            # No value to test: selected only if there is no include list restricting it.
            self.result = selector.include is None
            return

        def _hits(patterns) -> bool:
            return any(re.search(pattern, value) for pattern in patterns)

        if selector.exclude is not None and _hits(selector.exclude):
            self.matched = value
            self.result = False

        if selector.include is not None:
            # An include list, when present, has the final say over the result.
            included = _hits(selector.include)
            if included:
                self.matched = value
            self.result = included

    def __bool__(self) -> bool:
        """Expose the actual boolean so objects can be appropriately truthy"""
        return self.result
class ThrottledError(Exception):
    """Signal that a request was refused because its site is being throttled.

    ``retry_secs`` is the number of seconds the caller is asked to wait before retrying.
    """

    def __init__(self, message: str, retry_secs: int) -> None:
        self.retry_secs = retry_secs
        super().__init__(message)
57class Throttler:
58 DEFAULT_SITE = "DEFAULT_SITE"
60 def __init__(self, api_throttle_pause: int = 30, logger: Any | None = None, semaphore: Event | None = None) -> None:
61 self.log: Any = logger or log
62 self.pause_api_until: dict[str, float] = {}
63 self.api_throttle_pause: int = api_throttle_pause
64 self.semaphore = semaphore
66 def check_throttle(self, index_name: str | None = None) -> bool:
67 if self.semaphore and self.semaphore.is_set(): 67 ↛ 68line 67 didn't jump to line 68 because the condition on line 67 was never true
68 return True
69 index_name = index_name or self.DEFAULT_SITE
70 if self.pause_api_until.get(index_name) is not None:
71 if self.pause_api_until[index_name] < time.time():
72 del self.pause_api_until[index_name]
73 self.log.info("%s throttling wait complete", index_name)
74 else:
75 self.log.debug("%s throttling has %0.3f secs left", index_name, self.pause_api_until[index_name] - time.time())
76 return True
77 return False
79 def throttle(
80 self,
81 index_name: str | None = None,
82 retry_secs: int | None = None,
83 explanation: str | None = None,
84 raise_exception: bool = False,
85 ) -> None:
86 index_name = index_name or self.DEFAULT_SITE
87 retry_secs = retry_secs if retry_secs and retry_secs > 0 else self.api_throttle_pause
88 self.log.warn("%s throttling requests for %s seconds, %s", index_name, retry_secs, explanation)
89 self.pause_api_until[index_name] = time.time() + retry_secs
90 if raise_exception: 90 ↛ 91line 90 didn't jump to line 91 because the condition on line 90 was never true
91 raise ThrottledError(explanation or f"{index_name} throttled request", retry_secs)
class CacheMetadata:
    """Cache metadata extracted from hishel response extensions"""

    def __init__(self, response: Response) -> None:
        ext = response.extensions
        self.from_cache: bool = ext.get("hishel_from_cache", False)
        self.revalidated: bool = ext.get("hishel_revalidated", False)
        self.created_at: float | None = ext.get("hishel_created_at")
        self.stored: bool = ext.get("hishel_stored", False)
        # Seconds elapsed since ``created_at``; None when no creation time was reported.
        self.age: float | None = None if self.created_at is None else time.time() - self.created_at

    def __str__(self) -> str:
        """Summarize in a string"""
        summary = f"cached: {self.from_cache}, revalidated: {self.revalidated}, age:{self.age}, stored:{self.stored}"
        return summary
class APIStats:
    """Running fetch statistics for a single API host."""

    def __init__(self) -> None:
        self.fetches: int = 0  # every tick, including failures
        self.cached: int = 0  # responses served from the hishel cache
        self.revalidated: int = 0  # responses hishel marked as revalidated
        self.failed: dict[int, int] = {}  # status code -> count; key 0 = no response at all
        self.elapsed: float = 0  # summed request duration in seconds
        self.max_cache_age: float = 0  # oldest cache age seen, in seconds

    def tick(self, response: Response | None) -> None:
        """Record one fetch; ``None`` means the request raised before producing a response."""
        self.fetches += 1
        if response is None:
            self.failed[0] = self.failed.get(0, 0) + 1
            return
        cache_metadata: CacheMetadata = CacheMetadata(response)
        if cache_metadata.from_cache:
            self.cached += 1
        if cache_metadata.revalidated:
            self.revalidated += 1
        if response.elapsed:
            # total_seconds() includes the ``days`` component, which the previous
            # ``seconds + microseconds`` sum silently dropped.
            self.elapsed += response.elapsed.total_seconds()
        if not response.is_success:
            self.failed[response.status_code] = self.failed.get(response.status_code, 0) + 1
        if cache_metadata.age is not None and cache_metadata.age > self.max_cache_age:
            self.max_cache_age = cache_metadata.age

    def hit_ratio(self) -> float:
        """Fraction of fetches served from cache, rounded to 2 places (0 if none)."""
        return round(self.cached / self.fetches, 2) if self.cached and self.fetches else 0

    def average_elapsed(self) -> float:
        """Mean request duration in seconds, rounded to 2 places (0 if none)."""
        return round(self.elapsed / self.fetches, 2) if self.elapsed and self.fetches else 0

    def __str__(self) -> str:
        """Log line friendly string summary"""
        return (
            f"fetches: {self.fetches}, cache ratio: {self.hit_ratio():.2%}, revalidated: {self.revalidated}, "
            + f"errors: {', '.join(f'{status_code}:{fails}' for status_code, fails in self.failed.items()) or '0'}, "
            + f"oldest cache hit: {self.max_cache_age:.2f}s, avg elapsed: {self.average_elapsed()}s"
        )
class APIStatsCounter:
    """Keep an APIStats per host and log a combined summary every N fetches."""

    def __init__(self) -> None:
        self.stats_report_interval: int = 100  # emit a summary after every this many fetches
        self.host_stats: dict[str, APIStats] = {}
        self.fetches: int = 0
        self.log: Any = structlog.get_logger().bind()

    def stats(self, url: str, response: Response | None) -> None:
        """Record one fetch of ``url``; any error is logged as a warning, never raised."""
        try:
            hostname: str = urlparse(url).hostname or "UNKNOWN"
            self.host_stats.setdefault(hostname, APIStats()).tick(response)
            self.fetches += 1
            if self.fetches % self.stats_report_interval != 0:
                return
            summary = "\n".join(f"{name} {host_stat}" for name, host_stat in self.host_stats.items())
            self.log.info("OCI_V2 API Stats Summary\n%s", summary)
        except Exception as e:
            self.log.warning("Failed to tick stats: %s", e)
def fetch_url(
    url: str,
    cache_ttl: int | None = None,  # default to server responses for cache ttl
    bearer_token: str | None = None,
    response_type: str | list[str] | None = None,
    follow_redirects: bool = False,
    allow_stale: bool = False,
    method: str = "GET",
    api_stats_counter: APIStatsCounter | None = None,
) -> Response | None:
    """Fetch ``url`` through a private hishel cache client.

    Args:
        url: URL to request.
        cache_ttl: Seconds sent as a ``cache-control: max-age`` request header and as the
            ``hishel_ttl`` extension. When None, no ``cache-control`` header is sent, so
            freshness is left to the server's response headers.
        bearer_token: If given, sent as ``Authorization: Bearer ...`` (never logged).
        response_type: MIME type(s) to send, one ``Accept`` header per type.
        follow_redirects: Passed to the client.
        allow_stale: Passed to hishel ``CacheOptions``.
        method: HTTP method.
        api_stats_counter: If given, every attempt is recorded, with None for failures.

    Returns:
        The response (successful or not), or None if any exception was raised.
    """
    try:
        headers: list[tuple[str, str]] = []
        if cache_ttl is not None:
            # Previously sent unconditionally, producing a malformed "max-age=None" header.
            headers.append(("cache-control", f"max-age={cache_ttl}"))
        if bearer_token:
            headers.append(("Authorization", f"Bearer {bearer_token}"))
        if response_type:
            mime_types = [response_type] if isinstance(response_type, str) else response_type
            headers.extend(("Accept", mime_type) for mime_type in mime_types)

        cache_policy = SpecificationPolicy(
            cache_options=CacheOptions(
                shared=False,  # Private browser cache
                allow_stale=allow_stale,
            )
        )
        # Keep the bearer token out of the logs.
        log_headers: list[tuple[str, str]] = [h for h in headers if h[0] != "Authorization"]
        with SyncCacheClient(headers=headers, follow_redirects=follow_redirects, policy=cache_policy) as client:
            log.debug(
                "Fetching URL %s, redirects=%s, headers=%s, cache_ttl=%s", url, follow_redirects, log_headers, cache_ttl
            )
            response: Response = client.request(method=method, url=url, extensions={"hishel_ttl": cache_ttl})
            cache_metadata: CacheMetadata = CacheMetadata(response)
            if not response.is_success:
                log.debug("URL %s fetch returned non-success status: %s, %s", url, response.status_code, cache_metadata.stored)
            else:
                log.debug(
                    "URL response: status: %s, cached: %s, revalidated: %s, cache age: %s, stored: %s",
                    response.status_code,
                    cache_metadata.from_cache,
                    cache_metadata.revalidated,
                    cache_metadata.age,
                    cache_metadata.stored,
                )
            if api_stats_counter:
                api_stats_counter.stats(url, response)
            return response
    except Exception as e:
        log.debug("URL %s failed to fetch: %s", url, e)
        if api_stats_counter:
            api_stats_counter.stats(url, None)
        return None
def validate_url(url: str, cache_ttl: int = 1500) -> bool:
    """Probe ``url`` with a cached HEAD request, following redirects.

    Returns False if the request failed outright or the server answered 404;
    any other status code counts as valid.
    """
    probe = fetch_url(url, method="HEAD", cache_ttl=cache_ttl, follow_redirects=True)
    if probe is None:
        return False
    return probe.status_code != 404
def sanitize_name(name: str, replacement: str = "_", max_len: int = 64) -> str:
    """Strict sanitization that removes/replaces common problematic characters for MQTT or HA

    - Replaces every run of characters outside ``[A-Za-z0-9_.-]`` (spaces, control
      characters, ``/``, ``#``, ``+``, non-ASCII, ...) with ``replacement``
    - Collapses consecutive occurrences of ``replacement`` into a single one
    - Truncates to at most ``max_len`` UTF-8 bytes

    Args:
        name: The topic component string to sanitize
        replacement: String to replace invalid characters with (default: "_"); inserted literally
        max_len: Largest acceptable name size, in UTF-8 bytes

    Returns:
        Sanitized topic string safe for most MQTT brokers

    Raises:
        ValueError: if ``name`` is empty, or nothing remains after sanitization.
    """
    if not name:
        raise ValueError("Name cannot be empty")
    orig_name: str = name

    # A callable repl is inserted verbatim; a plain string repl would have its backslashes
    # interpreted by re.sub as escapes / group references (e.g. "\\" raises re.error).
    def literal(_match: re.Match[str]) -> str:
        return replacement

    name = re.sub(r"[^A-Za-z0-9_\-\.]+", literal, name)

    # Replace multiple consecutive replacement strings with a single one. The group makes
    # "+" repeat the whole replacement, not just its last character.
    if replacement:
        name = re.sub("(?:" + re.escape(replacement) + ")+", literal, name)

    # Trim to max length (in bytes); drop any multi-byte character split by the cut.
    encoded = name.encode("utf-8")
    if len(encoded) > max_len:
        name = encoded[:max_len].decode("utf-8", errors="ignore")

    if not name:
        raise ValueError("Topic became empty after sanitization")
    if name != orig_name:
        log.info("Component name %s changed to %s for MQTT/HA compatibility", orig_name, name)
    return name