Coverage for src/updates2mqtt/mqtt.py: 76%
329 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-14 15:07 +0000
1import asyncio
2import json
3import re
4import time
5from collections.abc import Callable
6from dataclasses import dataclass, field
7from threading import Event
8from typing import Any
10import paho.mqtt.client as mqtt
11import paho.mqtt.subscribeoptions
12import structlog
13from paho.mqtt.client import MQTT_CLEAN_START_FIRST_ONLY, MQTTMessage, MQTTMessageInfo
14from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode, MQTTProtocolVersion
15from paho.mqtt.properties import Properties
16from paho.mqtt.reasoncodes import ReasonCode
18from updates2mqtt.model import Discovery, ReleaseProvider
20from .config import HomeAssistantConfig, MqttConfig, NodeConfig, PublishPolicy
21from .hass_formatter import hass_format_config, hass_format_state
23log = structlog.get_logger()
25MQTT_NAME = r"[A-Za-z0-9_\-\.]+"
28@dataclass
29class LocalMessage:
30 topic: str | None = field(default=None)
31 payload: str | None = field(default=None)
34class MqttPublisher:
def __init__(self, cfg: MqttConfig, node_cfg: NodeConfig, hass_cfg: HomeAssistantConfig) -> None:
    """Keep the configuration and initialise empty connection state.

    No client is created here; `self.client` stays None until `start` is called.
    """
    # Configuration sections supplied by the caller
    self.cfg: MqttConfig = cfg
    self.node_cfg: NodeConfig = node_cfg
    self.hass_cfg: HomeAssistantConfig = hass_cfg

    # Registered providers, indexed by command topic and by source type
    self.providers_by_topic: dict[str, ReleaseProvider] = {}
    self.providers_by_type: dict[str, ReleaseProvider] = {}

    # Connection state
    self.event_loop: asyncio.AbstractEventLoop | None = None
    self.client: mqtt.Client | None = None
    self.fatal_failure = Event()
    self.connected = Event()

    # (source_type, comp_name) pairs whose command has been accepted but not yet finished
    self.commands_in_progress: set[tuple[str, str]] = set()

    self.log = structlog.get_logger().bind(host=cfg.host, integration="mqtt")
def start(self, event_loop: asyncio.AbstractEventLoop | None = None) -> None:
    """Create the MQTT client, request a connection and start its network loop.

    Args:
        event_loop: Loop used later to schedule command execution; when omitted,
            `asyncio.get_event_loop()` is used.

    Raises:
        OSError: Wraps any exception raised while setting up or connecting.
    """
    logger = self.log.bind(action="start")
    try:
        # Configuration spellings accepted for each protocol version
        known_protocols = (
            (("3", "3.11"), MQTTProtocolVersion.MQTTv311),
            (("3.1",), MQTTProtocolVersion.MQTTv31),
            (("5", "5.0"), MQTTProtocolVersion.MQTTv5),
        )
        protocol: MQTTProtocolVersion
        for spellings, version in known_protocols:
            if self.cfg.protocol in spellings:
                protocol = version
                break
        else:
            logger.info("No valid MQTT protocol version found (%s), setting to default v3.11", self.cfg.protocol)
            protocol = MQTTProtocolVersion.MQTTv311
        logger.debug("MQTT protocol set to %r", protocol)

        self.event_loop = event_loop or asyncio.get_event_loop()
        # clean_session is passed as None for v5, True for the v3 variants
        clean_session = None if protocol == MQTTProtocolVersion.MQTTv5 else True
        self.client = mqtt.Client(
            callback_api_version=CallbackAPIVersion.VERSION2,
            client_id=f"updates2mqtt_{self.node_cfg.name}",
            clean_session=clean_session,
            protocol=protocol,
        )
        self.client.username_pw_set(self.cfg.user, password=self.cfg.password)
        rc: MQTTErrorCode = self.client.connect(
            host=self.cfg.host,
            port=self.cfg.port,
            keepalive=self.cfg.keepalive,
            clean_start=MQTT_CLEAN_START_FIRST_ONLY,
        )
        logger.info("Client connection requested", result_code=rc)

        self.client.on_connect = self.on_connect
        self.client.on_disconnect = self.on_disconnect
        self.client.on_message = self.on_message
        self.client.on_subscribe = self.on_subscribe
        self.client.on_unsubscribe = self.on_unsubscribe

        self.client.loop_start()

        # Block until on_connect sets the event or the timeout expires
        connected_in_time = self.connected.wait(timeout=self.cfg.connect_timeout)
        if not (connected_in_time or self.fatal_failure.is_set()):
            logger.warning("Timed out waiting for broker connection, continuing anyway", timeout=self.cfg.connect_timeout)

        logger.debug("MQTT Publisher loop started", host=self.cfg.host, port=self.cfg.port)
    except Exception as e:
        logger.error("Failed to connect to broker", host=self.cfg.host, port=self.cfg.port, error=str(e))
        raise OSError(f"Connection Failure to {self.cfg.host}:{self.cfg.port} as {self.cfg.user} -- {e}") from e
def stop(self) -> None:
    """Stop the network loop, disconnect and forget the client; do nothing if there is none."""
    client = self.client
    if not client:
        return
    client.loop_stop()
    client.disconnect()
    self.client = None
    self.connected.clear()
def is_available(self) -> bool:
    """Report whether a client exists and no fatal failure has been flagged."""
    if self.client is None:
        return False
    return not self.fatal_failure.is_set()
def on_connect(
    self, _client: mqtt.Client, _userdata: Any, _flags: mqtt.ConnectFlags, rc: ReasonCode, _props: Properties | None
) -> None:
    """Handle the broker's reply to a connection request.

    A "Not authorized" reply sets `fatal_failure`. Any other non-zero reason code is
    logged as a failed connection. On success the `connected` event is set and every
    registered command topic is subscribed again.
    """
    if not self.client or self.fatal_failure.is_set():
        self.log.warning("No client, check if started and authorized")
        return
    if rc.getName() == "Not authorized":
        self.fatal_failure.set()
        # Use the bound logger so this entry carries the same context as the others here
        self.log.error("Invalid MQTT credentials", result_code=rc)
        return
    if rc != 0:
        self.log.warning("Connection failed to broker", result_code=rc)
        return

    self.log.debug("Connected successfully to MQTT broker")
    self.connected.set()
    for topic, provider in self.providers_by_topic.items():
        self.log.debug("(Re)subscribing", topic=topic, provider=provider.source_type)
        self.client.subscribe(topic)
def on_disconnect(
    self,
    _client: mqtt.Client,
    _userdata: Any,
    _disconnect_flags: mqtt.DisconnectFlags,
    rc: ReasonCode,
    _props: Properties | None,
) -> None:
    """Clear the `connected` event and log the disconnect, as a warning when rc is non-zero."""
    self.connected.clear()
    if rc != 0:
        self.log.warning("Disconnect failure from broker", result_code=rc)
        return
    self.log.debug("Disconnected from broker", result_code=rc)
async def clean_topics(
    self, provider: ReleaseProvider, wait_time: int = 5, max_time: int = 120, initial: bool = False
) -> None:
    """Clear stale retained topics for one provider using a short-lived MQTT client.

    A separate client subscribes to the Home Assistant discovery tree and to this
    node's topic tree for the provider, then inspects each message it receives.
    A topic is cleared by publishing an empty retained payload to it.

    Args:
        provider: Provider whose topics are inspected.
        wait_time: Stop once this many seconds pass without a matching message.
        max_time: Upper bound in seconds for the whole cycle.
        initial: When True, topics with no known discovery are kept, but lingering
            in-progress flags and incomplete state payloads are reset.

    Failures are logged rather than raised. The temporary client is always
    disconnected before returning.
    """
    logger = self.log.bind(action="clean")

    if self.fatal_failure.is_set():
        return
    # Tracks the temporary client so it can be released even when the cycle fails
    connection: mqtt.Client | None = None
    try:
        logger.info("Starting clean cycle, wait time: %s, max time: %s, initial: %s", wait_time, max_time, initial)
        cutoff_time: float = time.time() + max_time
        cleaner = mqtt.Client(
            callback_api_version=CallbackAPIVersion.VERSION1,
            client_id=f"updates2mqtt_clean_{self.node_cfg.name}",
            clean_session=True,
        )
        connection = cleaner
        results = {"cleaned": 0, "matched": 0, "discovered": 0, "last_timestamp": time.time()}
        cleaner.username_pw_set(self.cfg.user, password=self.cfg.password)
        cleaner.connect(host=self.cfg.host, port=self.cfg.port, keepalive=self.cfg.keepalive)

        # Topic prefixes owned by this node and provider
        config_prefix = f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{provider.source_type}_"
        node_prefix = f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/"

        def cleanup(_client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
            discovery: Discovery | None = None
            if msg.topic.startswith(config_prefix):
                discovery = self.reverse_config_topic(msg.topic, provider.source_type)
            elif msg.topic.startswith(node_prefix) and msg.topic.endswith("/state"):
                discovery = self.reverse_state_topic(msg.topic, provider.source_type)
            elif msg.topic.startswith(node_prefix):
                discovery = self.reverse_general_topic(msg.topic, provider.source_type)
            else:
                logger.debug("Ignoring other topic ", topic=msg.topic)
                return
            results["discovered"] += 1
            # NOTE(review): an empty retained publish may be echoed back to this subscriber
            # and counted again here — confirm against the broker in use.
            if not initial and discovery is None:
                logger.debug("Removing unknown discovery", topic=msg.topic)
                cleaner.publish(msg.topic, "", retain=True)
                results["cleaned"] += 1
            elif discovery is not None:
                results["matched"] += 1

            try:
                if msg.payload:
                    payload = json.loads(msg.payload)
                    update_section = payload.get("update") if isinstance(payload.get("update"), dict) else None
                    lingering_in_progress = payload.get("in_progress") or (
                        update_section and update_section.get("in_progress")
                    )
                    if lingering_in_progress and initial:
                        logger.info("Clearing lingering in-progress state at %s", msg.topic)
                        payload["in_progress"] = False
                        if update_section is not None:
                            update_section["in_progress"] = False
                        cleaner.publish(msg.topic, json.dumps(payload), retain=True)
                        results["cleaned"] += 1
                    elif (
                        initial
                        and msg.topic.endswith("/state")
                        and (payload.get("installed_version") is None or payload.get("latest_version") is None)
                    ):
                        # Stale/incomplete state message (e.g. from an older schema) leaves HA showing
                        # "unknown" forever, since it has nothing to compare against. Clear it so the
                        # upcoming scan can publish a complete replacement.
                        logger.info("Clearing stale incomplete state at %s", msg.topic)
                        cleaner.publish(msg.topic, "", retain=True)
                        results["cleaned"] += 1
            except Exception as e:
                # Anything that cannot be parsed or inspected is treated as garbage and cleared
                logger.warning("Invalid payload at %s: %s", msg.topic, e)
                cleaner.publish(msg.topic, "", retain=True)
                results["cleaned"] += 1

            # Each handled message extends the quiet-period deadline of the loop below
            results["last_timestamp"] = time.time()

        cleaner.on_message = cleanup
        options = paho.mqtt.subscribeoptions.SubscribeOptions(noLocal=True)
        cleaner.subscribe(f"{self.hass_cfg.discovery.prefix}/update/#", options=options)
        cleaner.subscribe(f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/#", options=options)

        while time.time() - results["last_timestamp"] <= wait_time and time.time() <= cutoff_time:
            cleaner.loop(0.5)

        logger.info(
            f"Cleaned - discovered:{results['discovered']}, matched:{results['matched']}, cleaned:{results['cleaned']}"
        )
    except Exception as e:
        logger.exception("Cleaning topics of stale entries failed: %s", e)
    finally:
        # Release the temporary connection; previously it was left open after every cycle
        if connection is not None:
            try:
                connection.disconnect()
            except Exception as e:
                logger.debug("Cleaner disconnect failed: %s", e)
def safe_json_decode(self, jsonish: str | bytes | None) -> dict:
    """Decode JSON leniently, falling back to an empty dict.

    None yields `{}`. If decoding fails, a second attempt is made with the first and
    last character removed; if that fails too, `{}` is returned. Each failed attempt
    is logged with its traceback.
    """
    if jsonish is None:
        return {}
    decoded: dict = {}
    try:
        decoded = json.loads(jsonish)
    except Exception:
        log.exception("JSON decode fail (%s)", jsonish)
        try:
            # Retry without the outermost characters
            decoded = json.loads(jsonish[1:-1])
        except Exception:
            log.exception("JSON decode fail (%s)", jsonish[1:-1])
    return decoded
def validate_command(self, msg: MQTTMessage | LocalMessage) -> tuple[ReleaseProvider, str, str] | None:
    """Parse and validate a `source_type|comp_name|command` message.

    Args:
        msg: Incoming message; its payload may be bytes (decoded as UTF-8) or str.

    Returns:
        `(provider, comp_name, command)` when the message is acceptable, in which case
        `(source_type, comp_name)` is also added to `commands_in_progress`.
        None when the format is wrong, no provider is registered for the topic, the
        source type does not match, the command is not `install`, the component name
        is empty, or the same command is already in progress.
    """
    logger = self.log.bind(topic=msg.topic, payload=msg.payload)
    try:
        logger.info("Command received for %s", msg.topic)

        payload: str | None = None
        if isinstance(msg.payload, bytes):
            payload = msg.payload.decode("utf-8")
        elif isinstance(msg.payload, str):
            payload = msg.payload

        # Exactly three fields are required; checked here so a wrong field count is
        # reported as a format problem rather than as an unexpected error.
        parts: list[str] = payload.split("|") if payload else []
        if len(parts) != 3:
            logger.warning("Invalid command format, expecting `source_type|comp_name|command`")
            return None
        source_type, comp_name, command = parts
        logger.debug("Validating %s:%s:%s", source_type, comp_name, command)

        provider: ReleaseProvider | None = self.providers_by_topic.get(msg.topic) if msg.topic else None

        if not provider:
            logger.warning("Unexpected provider type %s", msg.topic)
            return None
        if provider.source_type != source_type:
            logger.warning("Unexpected source type %s", source_type)
            return None
        if command != "install":
            logger.warning("Unknown command: %s", command)
            return None
        if not comp_name:
            logger.warning("Missing comp_name in command message: %s", msg.payload)
            return None

        in_progress_key: tuple[str, str] = (source_type, comp_name)
        if in_progress_key in self.commands_in_progress:
            logger.warning("Ignoring duplicate %s command for %s, already in progress", command, comp_name)
            return None
        self.commands_in_progress.add(in_progress_key)
        return (provider, comp_name, command)
    except Exception:
        # Keep the traceback so the cause is visible in the log
        logger.exception("Unexpected error validating command")
        return None
async def execute_command(
    self, provider: ReleaseProvider, comp_name: str, command: str, on_update_start: Callable, on_update_end: Callable
) -> None:
    """Run a validated command through its provider and republish the result.

    The `(source_type, comp_name)` entry is removed from `commands_in_progress`
    whether or not the provider call succeeds. Exceptions are logged, not raised.
    """
    # TODO: defer handling of commands where repository is throttled
    logger = self.log.bind(source_type=provider.source_type, comp_name=comp_name, command=command)
    try:
        logger.info("Execution starting for %s %s", command, comp_name)

        in_progress_key: tuple[str, str] = (provider.source_type, comp_name)
        logger.info("Passing %s command to %s scanner for %s", command, provider.source_type, comp_name)
        try:
            updated: bool = provider.command(comp_name, command, on_update_start, on_update_end)
            discovery = provider.resolve(comp_name)
            if not (updated and discovery):
                logger.debug("No change to republish after execution")
            else:
                if discovery.publish_policy == PublishPolicy.HOMEASSISTANT and self.hass_cfg.discovery.enabled:
                    self.publish_hass_config(discovery)
                if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                    self.publish_discovery(discovery)
                if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                    self.publish_hass_state(discovery)
        finally:
            # discard() is a no-op when the key is absent
            self.commands_in_progress.discard(in_progress_key)
        logger.info("Execution ended")
    except Exception:
        logger.exception("Execution failed")
def local_message(self, discovery: Discovery, command: str) -> None:
    """Simulate an incoming MQTT message for local commands"""
    # Same topic and `source_type|name|command` payload a remote command would use
    simulated = LocalMessage(
        topic=self.command_topic(discovery.provider),
        payload="|".join((discovery.source_type, discovery.name, command)),
    )
    self.handle_message(simulated)
def on_subscribe(
    self,
    _client: mqtt.Client,
    userdata: Any,
    mid: int,
    reason_code_list: list[ReasonCode],
    properties: Properties | None = None,
) -> None:
    """Log the details passed to the subscribe callback at debug level."""
    details = (userdata, mid, reason_code_list, properties)
    self.log.debug("on_subscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", *details)
def on_unsubscribe(
    self,
    _client: mqtt.Client,
    userdata: Any,
    mid: int,
    reason_code_list: list[ReasonCode],
    properties: Properties | None = None,
) -> None:
    """Log the details passed to the unsubscribe callback at debug level."""
    details = (userdata, mid, reason_code_list, properties)
    self.log.debug("on_unsubscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", *details)
def on_message(self, _client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
    """Callback for incoming MQTT messages"""  # noqa: D401
    if msg.topic not in self.providers_by_topic:
        # apparently the root non-wildcard sub sometimes brings in child topics
        self.log.debug("Unhandled message #%s on %s:%s", msg.mid, msg.topic, msg.payload)
        return
    self.handle_message(msg)
def handle_message(self, msg: mqtt.MQTTMessage | LocalMessage) -> None:
    """Validate a command message and schedule its execution on the event loop.

    Nothing is scheduled when the message fails validation. When no event loop has
    been set, an error is logged and the message is dropped.
    """

    def update_start(discovery: Discovery) -> None:
        # Publish the in-progress state; the topic is passed as a structured field
        # (the previous format string had a placeholder with no argument)
        self.log.debug("on_update_start", topic=msg.topic)
        if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
            self.publish_hass_state(discovery, in_progress=True)
        if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
            self.publish_discovery(discovery, in_progress=True)

    def update_end(discovery: Discovery) -> None:
        # Publish the finished state
        self.log.debug("on_update_end", topic=msg.topic)
        if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
            self.publish_hass_state(discovery, in_progress=False)
        if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
            self.publish_discovery(discovery, in_progress=False)

    # TODO: fix double publish on callback and in command exec
    if self.event_loop is None:
        self.log.error("No event loop to handle message", topic=msg.topic)
        return

    self.log.debug("Executing command topic", topic=msg.topic)
    parsed: tuple[ReleaseProvider, str, str] | None = self.validate_command(msg=msg)
    if parsed is None:
        return
    provider, comp_name, command = parsed
    # The coroutine is handed to the stored loop
    asyncio.run_coroutine_threadsafe(
        self.execute_command(
            provider=provider,
            comp_name=comp_name,
            command=command,
            on_update_start=update_start,
            on_update_end=update_end,
        ),
        loop=self.event_loop,
    )
def config_topic(self, discovery: Discovery) -> str:
    """Build the Home Assistant discovery config topic for a discovery."""
    entity = f"{self.node_cfg.name}_{discovery.source_type}_{discovery.name}"
    return f"{self.hass_cfg.discovery.prefix}/update/{entity}/update/config"
def reverse_config_topic(self, topic: str, source_type: str) -> Discovery | None:
    """Map a discovery config topic back to the known discovery it was published for.

    Returns None, after a debug log, when the topic does not have the expected shape
    or names a discovery that the provider for `source_type` does not hold.
    """
    # Configured values are matched literally, so characters such as "." in a prefix
    # or node name cannot act as regex operators.
    pattern = (
        f"{re.escape(str(self.hass_cfg.discovery.prefix))}/update/"
        f"{re.escape(str(self.node_cfg.name))}_{re.escape(str(source_type))}_({MQTT_NAME})/update/config"
    )
    match = re.fullmatch(pattern, topic)
    provider = self.providers_by_type.get(source_type)
    if match and provider is not None:
        discovery_name: str = match.group(1)
        if discovery_name in provider.discoveries:
            return provider.discoveries[discovery_name]

    self.log.debug("MQTT CONFIG no match for %s", topic)
    return None
def state_topic(self, discovery: Discovery) -> str:
    """Build the state topic for a discovery under this node's topic root."""
    base = f"{self.cfg.topic_root}/{self.node_cfg.name}"
    return f"{base}/{discovery.source_type}/{discovery.name}/state"
def reverse_state_topic(self, topic: str, source_type: str) -> Discovery | None:
    """Map a state topic back to the known discovery it was published for.

    Returns None, after a debug log, when the topic does not have the expected shape,
    no provider is registered for `source_type`, or the provider does not hold the
    named discovery.
    """
    # Configured values are matched literally rather than as regex fragments
    pattern = (
        f"{re.escape(str(self.cfg.topic_root))}/{re.escape(str(self.node_cfg.name))}"
        f"/{re.escape(str(source_type))}/({MQTT_NAME})/state"
    )
    match = re.fullmatch(pattern, topic)
    # .get() so an unregistered source type is a miss instead of a KeyError
    provider = self.providers_by_type.get(source_type)
    if match and provider is not None:
        discovery_name: str = match.group(1)
        if discovery_name in provider.discoveries:
            return provider.discoveries[discovery_name]

    self.log.debug("MQTT STATE no match for %s", topic)
    return None
def general_topic(self, discovery: Discovery) -> str:
    """Build the general (non Home Assistant specific) topic for a discovery."""
    base = f"{self.cfg.topic_root}/{self.node_cfg.name}"
    return f"{base}/{discovery.source_type}/{discovery.name}"
def reverse_general_topic(self, topic: str, source_type: str) -> Discovery | None:
    """Map a general topic back to the known discovery it was published for.

    Returns None, after a debug log, when the topic does not have the expected shape,
    no provider is registered for `source_type`, or the provider does not hold the
    named discovery.
    """
    # Configured values are matched literally rather than as regex fragments
    pattern = (
        f"{re.escape(str(self.cfg.topic_root))}/{re.escape(str(self.node_cfg.name))}"
        f"/{re.escape(str(source_type))}/({MQTT_NAME})"
    )
    match = re.fullmatch(pattern, topic)
    # .get() so an unregistered source type is a miss instead of a KeyError
    provider = self.providers_by_type.get(source_type)
    if match and provider is not None:
        discovery_name: str = match.group(1)
        if discovery_name in provider.discoveries:
            return provider.discoveries[discovery_name]

    self.log.debug("MQTT ATTR no match for %s", topic)
    return None
def command_topic(self, provider: ReleaseProvider) -> str:
    """Build the topic on which commands for a provider are received."""
    segments = (self.cfg.topic_root, self.node_cfg.name, provider.source_type)
    return f"{segments[0]}/{segments[1]}/{segments[2]}"
def publish_discovery(self, discovery: Discovery, in_progress: bool = False) -> None:
    """Comprehensive, non Home Assistant specific, base publication"""
    publishable = (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT)
    if discovery.publish_policy not in publishable:
        return
    self.log.debug("Discovery publish: %s", discovery)
    payload: dict[str, Any] = discovery.as_dict()
    payload["update"]["in_progress"] = in_progress  # ty:ignore[invalid-assignment]
    # Truncate the release summary when a maximum size is configured
    summary = payload.get("release", {}).get("summary")
    if summary and self.hass_cfg.release_summary_max_size:
        payload["release"]["summary"] = summary[: self.hass_cfg.release_summary_max_size]
    self.publish(self.general_topic(discovery), payload)
def publish_hass_state(self, discovery: Discovery, in_progress: bool = False) -> None:
    """Publish the Home Assistant state payload; skipped unless the policy is HOMEASSISTANT."""
    if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
        return
    self.log.debug("HASS State update, in progress: %s, discovery: %s", in_progress, discovery)
    topic = self.state_topic(discovery)
    state = hass_format_state(
        discovery, in_progress=in_progress, release_summary_max_size=self.hass_cfg.release_summary_max_size
    )
    self.publish(topic, state)
def publish_hass_config(self, discovery: Discovery) -> None:
    """Publish the Home Assistant discovery config; skipped unless the policy is HOMEASSISTANT."""
    if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
        return
    object_id = f"{discovery.source_type}_{self.node_cfg.name}_{discovery.name}"
    self.log.debug("HASS Config: %s", object_id)

    topic = self.config_topic(discovery)
    # The attributes topic is only supplied when extra attributes are enabled
    config = hass_format_config(
        discovery=discovery,
        object_id=object_id,
        area=self.hass_cfg.area,
        state_topic=self.state_topic(discovery),
        attrs_topic=self.general_topic(discovery) if self.hass_cfg.extra_attributes else None,
        command_topic=self.command_topic(discovery.provider),
        force_command_topic=self.hass_cfg.force_command_topic,
        device_creation=self.hass_cfg.device_creation,
    )
    self.publish(topic, config)
def subscribe_hass_command(self, provider: ReleaseProvider):  # noqa: ANN201
    """Register a provider under its command topic and subscribe to that topic.

    Skipped when the topic is already registered or there is no client.
    """
    topic = self.command_topic(provider)
    if topic not in self.providers_by_topic and self.client is not None:
        self.log.info("Handler subscribing", topic=topic)
        self.providers_by_topic[topic] = provider
        self.providers_by_type[provider.source_type] = provider
        self.client.subscribe(topic)
    else:
        self.log.debug("Skipping subscription", topic=topic)
    return topic
def loop_once(self) -> None:
    """Run one iteration of the client's network loop, if a client exists."""
    client = self.client
    if not client:
        return
    client.loop()
def publish(self, topic: str, payload: dict, qos: int = 1, retain: bool = True) -> None:
    """Serialise `payload` as JSON and publish it, logging the outcome.

    Args:
        topic: Destination topic.
        payload: Data to serialise with `json.dumps`.
        qos: Quality-of-service level passed to the client.
        retain: Whether the message is published as retained.
    """
    if not self.client:
        self.log.debug("No client to publish at %s", topic)
        return

    info: MQTTMessageInfo = self.client.publish(topic, payload=json.dumps(payload), qos=qos, retain=retain)
    if info.rc == MQTTErrorCode.MQTT_ERR_SUCCESS:
        self.log.debug(
            "Publish to %s, mid: %s, published: %s, qos: %s, rc: %s", topic, info.mid, info.is_published(), qos, info.rc
        )
        return
    if info.rc == MQTTErrorCode.MQTT_ERR_NO_CONN and qos > 0:
        self.log.debug(
            "Not currently connected, queued for delivery on reconnect: %s, mid: %s, qos: %s",
            topic,
            info.mid,
            qos,
        )
        return
    self.log.warning("Problem publishing to %s, mid: %s, qos: %s, rc: %s", topic, info.mid, qos, info.rc)