Coverage for src / updates2mqtt / mqtt.py: 76%
283 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-20 23:13 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-20 23:13 +0000
1import asyncio
2import json
3import re
4import time
5from collections.abc import Callable
6from dataclasses import dataclass, field
7from threading import Event
8from typing import Any
10import paho.mqtt.client as mqtt
11import paho.mqtt.subscribeoptions
12import structlog
13from paho.mqtt.client import MQTT_CLEAN_START_FIRST_ONLY, MQTTMessage, MQTTMessageInfo
14from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode, MQTTProtocolVersion
15from paho.mqtt.properties import Properties
16from paho.mqtt.reasoncodes import ReasonCode
18from updates2mqtt.model import Discovery, ReleaseProvider
20from .config import HomeAssistantConfig, MqttConfig, NodeConfig, PublishPolicy
21from .hass_formatter import hass_format_config, hass_format_state
23log = structlog.get_logger()
25MQTT_NAME = r"[A-Za-z0-9_\-\.]+"
@dataclass
class LocalMessage:
    """Stand-in for a broker message, used for commands raised in-process.

    Exposes the ``topic`` and ``payload`` attributes that the command
    handling code reads from an incoming message.
    """

    # Topic the simulated message is addressed to.
    topic: str | None = None
    # Command text carried by the simulated message.
    payload: str | None = None
34class MqttPublisher:
def __init__(self, cfg: MqttConfig, node_cfg: NodeConfig, hass_cfg: HomeAssistantConfig) -> None:
    """Record the configuration and set up empty connection state.

    Nothing is connected here; the client is created by ``start``.

    Args:
        cfg: Broker settings (host, port, credentials, protocol, topic root).
        node_cfg: Settings of this node; its name is used in topics and client ids.
        hass_cfg: Home Assistant related settings.
    """
    # Configuration sections, kept as passed in.
    self.cfg: MqttConfig = cfg
    self.node_cfg: NodeConfig = node_cfg
    self.hass_cfg: HomeAssistantConfig = hass_cfg

    # Registered providers, by command topic and by source type.
    self.providers_by_topic: dict[str, ReleaseProvider] = {}
    self.providers_by_type: dict[str, ReleaseProvider] = {}

    # Runtime state, populated by ``start``.
    self.event_loop: asyncio.AbstractEventLoop | None = None
    self.client: mqtt.Client | None = None

    # Set once the broker rejects the credentials.
    self.fatal_failure = Event()

    self.log = structlog.get_logger().bind(host=cfg.host, integration="mqtt")
def start(self, event_loop: asyncio.AbstractEventLoop | None = None) -> None:
    """Create the MQTT client, request a broker connection and start its loop.

    Args:
        event_loop: Loop stored for scheduling command execution; when
            omitted, ``asyncio.get_event_loop()`` is used.

    Raises:
        OSError: If any step fails; the original exception is chained.
    """
    logger = self.log.bind(action="start")
    try:
        requested = self.cfg.protocol
        protocol: MQTTProtocolVersion
        if requested == "3.1":
            protocol = MQTTProtocolVersion.MQTTv31
        elif requested in ("5", "5.0"):
            protocol = MQTTProtocolVersion.MQTTv5
        else:
            # Anything that is not an explicit v3.11 request is reported
            # before falling back to the v3.11 default.
            if requested not in ("3", "3.11"):
                logger.info("No valid MQTT protocol version found (%s), setting to default v3.11", self.cfg.protocol)
            protocol = MQTTProtocolVersion.MQTTv311
        logger.debug("MQTT protocol set to %r", protocol)

        self.event_loop = event_loop or asyncio.get_event_loop()

        # clean_session is passed as None for v5, where clean_start on
        # connect is used instead.
        clean_session = None if protocol == MQTTProtocolVersion.MQTTv5 else True
        client = mqtt.Client(
            callback_api_version=CallbackAPIVersion.VERSION2,
            client_id=f"updates2mqtt_{self.node_cfg.name}",
            clean_session=clean_session,
            protocol=protocol,
        )
        self.client = client
        client.username_pw_set(self.cfg.user, password=self.cfg.password)
        rc: MQTTErrorCode = client.connect(
            host=self.cfg.host,
            port=self.cfg.port,
            keepalive=60,
            clean_start=MQTT_CLEAN_START_FIRST_ONLY,
        )
        logger.info("Client connection requested", result_code=rc)

        # Callbacks are attached before the network loop thread is started.
        client.on_connect = self.on_connect
        client.on_disconnect = self.on_disconnect
        client.on_message = self.on_message
        client.on_subscribe = self.on_subscribe
        client.on_unsubscribe = self.on_unsubscribe

        client.loop_start()

        logger.debug("MQTT Publisher loop started", host=self.cfg.host, port=self.cfg.port)
    except Exception as e:
        logger.error("Failed to connect to broker", host=self.cfg.host, port=self.cfg.port, error=str(e))
        raise OSError(f"Connection Failure to {self.cfg.host}:{self.cfg.port} as {self.cfg.user} -- {e}") from e
def stop(self) -> None:
    """Stop the network loop, disconnect and drop the client; no-op without one."""
    client = self.client
    if not client:
        return
    client.loop_stop()
    client.disconnect()
    self.client = None
def is_available(self) -> bool:
    """Return True when a client exists and no fatal failure has been flagged."""
    if self.client is None:
        return False
    return not self.fatal_failure.is_set()
def on_connect(
    self, _client: mqtt.Client, _userdata: Any, _flags: mqtt.ConnectFlags, rc: ReasonCode, _props: Properties | None
) -> None:
    """Handle the result of a connection attempt.

    A "Not authorized" result sets ``fatal_failure``. A successful result
    (re)subscribes every topic in ``providers_by_topic``. Any other result
    is logged as a warning.

    Args:
        rc: Reason code of the connection attempt.
    """
    if not self.client or self.fatal_failure.is_set():
        self.log.warning("No client, check if started and authorized")
        return
    if rc.getName() == "Not authorized":
        self.fatal_failure.set()
        # Logged through the instance logger so the bound host/integration
        # context is kept, like every other message in this callback.
        self.log.error("Invalid MQTT credentials", result_code=rc)
        return
    if rc != 0:
        self.log.warning("Connection failed to broker", result_code=rc)
        return

    self.log.debug("Connected to broker", result_code=rc)
    for topic, provider in self.providers_by_topic.items():
        self.log.debug("(Re)subscribing", topic=topic, provider=provider.source_type)
        self.client.subscribe(topic)
def on_disconnect(
    self,
    _client: mqtt.Client,
    _userdata: Any,
    _disconnect_flags: mqtt.DisconnectFlags,
    rc: ReasonCode,
    _props: Properties | None,
) -> None:
    """Log a disconnect: debug when the reason code equals 0, warning otherwise."""
    if rc == 0:
        self.log.debug("Disconnected from broker", result_code=rc)
        return
    self.log.warning("Disconnect failure from broker", result_code=rc)
async def clean_topics(
    self, provider: ReleaseProvider, wait_time: int = 5, max_time: int = 120, initial: bool = False
) -> None:
    """Clear stale topics belonging to ``provider`` on this node.

    A temporary client subscribes to the discovery and provider topic
    trees and inspects each message it receives. A topic is cleared, by
    publishing an empty retained payload, when:

    * it maps to no known discovery and ``initial`` is false, or
    * its payload reports ``in_progress`` and ``initial`` is true, or
    * its payload cannot be processed as a JSON object.

    Args:
        provider: Provider whose topics are inspected.
        wait_time: Seconds without a processed message after which the scan ends.
        max_time: Upper bound in seconds for the whole scan.
        initial: Selects which of the rules above apply.

    Failures are logged, not raised.
    """
    logger = self.log.bind(action="clean")
    if self.fatal_failure.is_set():
        return
    cleaner: mqtt.Client | None = None
    try:
        logger.info("Starting clean cycle, wait time: %s, max time: %s, initial: %s", wait_time, max_time, initial)
        cutoff_time: float = time.time() + max_time
        # Only on_message is used here, so the same callback API version as
        # the main client avoids the deprecated version 1.
        cleaner = mqtt.Client(
            callback_api_version=CallbackAPIVersion.VERSION2,
            client_id=f"updates2mqtt_clean_{self.node_cfg.name}",
            clean_session=True,
        )
        results = {"cleaned": 0, "matched": 0, "discovered": 0, "last_timestamp": time.time()}
        cleaner.username_pw_set(self.cfg.user, password=self.cfg.password)
        cleaner.connect(host=self.cfg.host, port=self.cfg.port, keepalive=60)

        config_prefix = f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{provider.source_type}_"
        provider_prefix = f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/"

        def cleanup(_client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
            if not msg.payload:
                # An empty payload is what this routine itself publishes to
                # clear a topic; there is nothing to validate or clear again.
                # NOTE(review): noLocal is an MQTT v5 subscription option;
                # confirm whether it suppresses these echoes for this client.
                return

            discovery: Discovery | None = None
            if msg.topic.startswith(config_prefix):
                discovery = self.reverse_config_topic(msg.topic, provider.source_type)
            elif msg.topic.startswith(provider_prefix) and msg.topic.endswith("/state"):
                discovery = self.reverse_state_topic(msg.topic, provider.source_type)
            elif msg.topic.startswith(provider_prefix):
                discovery = self.reverse_general_topic(msg.topic, provider.source_type)
            else:
                logger.debug("Ignoring other topic ", topic=msg.topic)
                return

            results["discovered"] += 1
            stale = False
            if discovery is not None:
                results["matched"] += 1
            elif not initial:
                logger.debug("Removing unknown discovery", topic=msg.topic)
                stale = True

            try:
                payload = json.loads(msg.payload)
                if payload.get("in_progress") and initial:
                    stale = True
            except Exception as e:
                logger.warning("Invalid payload at %s: %s", msg.topic, e)
                stale = True

            if stale:
                # One clear and one count per topic, whichever rules matched.
                cleaner.publish(msg.topic, "", retain=True)
                results["cleaned"] += 1

            results["last_timestamp"] = time.time()

        cleaner.on_message = cleanup
        options = paho.mqtt.subscribeoptions.SubscribeOptions(noLocal=True)
        cleaner.subscribe(f"{self.hass_cfg.discovery.prefix}/update/#", options=options)
        cleaner.subscribe(f"{provider_prefix}#", options=options)

        # NOTE(review): this loop contains no await, so the coroutine does
        # not yield to other tasks while scanning - confirm that is intended.
        while time.time() - results["last_timestamp"] <= wait_time and time.time() <= cutoff_time:
            cleaner.loop(0.5)

        logger.info(
            f"Cleaned - discovered:{results['discovered']}, matched:{results['matched']}, cleaned:{results['cleaned']}"
        )
    except Exception as e:
        logger.exception("Cleaning topics of stale entries failed: %s", e)
    finally:
        # Always release the temporary connection, including on failure.
        if cleaner is not None:
            try:
                cleaner.disconnect()
            except Exception as e:
                logger.debug("Cleaner disconnect failed: %s", e)
def safe_json_decode(self, jsonish: str | bytes | None) -> dict:
    """Decode a JSON object, returning an empty dict when that is not possible.

    Decoding is tried on the input as given and, if that raises, on the
    input with its first and last characters removed.

    Args:
        jsonish: JSON text or bytes, or ``None``.

    Returns:
        The decoded value when it is a ``dict``, otherwise ``{}``.
    """
    if jsonish is None:
        return {}
    candidate = jsonish
    for trimmed in (False, True):
        try:
            if trimmed:
                # Sliced inside the try so an unsliceable input is logged
                # like any other decode failure instead of raising.
                candidate = jsonish[1:-1]
            decoded = json.loads(candidate)
        except Exception:
            log.exception("JSON decode fail (%s)", candidate)
            continue
        # Valid JSON that is not an object (list, string, number) would
        # break the declared ``dict`` return type for callers.
        return decoded if isinstance(decoded, dict) else {}
    return {}
async def execute_command(
    self, msg: MQTTMessage | LocalMessage, on_update_start: Callable, on_update_end: Callable
) -> None:
    """Run the command carried by ``msg`` and republish the affected discovery.

    The payload is expected in the form ``source_type|component|command``.
    It is passed to the provider registered for ``msg.topic`` only when the
    source type matches that provider and the command is ``install``.

    Args:
        msg: Incoming or locally simulated command message.
        on_update_start: Callback handed to the provider's ``command``.
        on_update_end: Callback handed to the provider's ``command``.

    Failures are logged, not raised.
    """
    # TODO: defer handling of commands where repository is throttled
    logger = self.log.bind(topic=msg.topic, payload=msg.payload)
    comp_name: str | None = None
    command: str | None = None
    try:
        logger.info("Execution starting")
        source_type: str | None = None

        payload: str | None = None
        if isinstance(msg.payload, bytes):
            payload = msg.payload.decode("utf-8")
        elif isinstance(msg.payload, str):
            payload = msg.payload
        if payload and "|" in payload:
            parts = payload.split("|")
            if len(parts) != 3:
                # Reported as an invalid payload instead of letting the
                # unpacking error surface as an execution failure.
                logger.warning("Invalid payload in command message: %s", msg.payload)
                logger.info("Execution ended")
                return
            source_type, comp_name, command = parts
            logger.debug("Executing %s:%s:%s", source_type, comp_name, command)

        provider: ReleaseProvider | None = self.providers_by_topic.get(msg.topic) if msg.topic else None
        if not provider:
            logger.warning("Unexpected provider type %s", msg.topic)
        elif provider.source_type != source_type:
            logger.warning("Unexpected source type %s", source_type)
        elif command != "install" or not comp_name:
            logger.warning("Invalid payload in command message: %s", msg.payload)
        else:
            logger.info(
                "Passing %s command to %s scanner for %s",
                command,
                source_type,
                comp_name,
            )
            # NOTE(review): this is a plain call inside a coroutine;
            # confirm the provider does not block the loop for long.
            updated: bool = provider.command(comp_name, command, on_update_start, on_update_end)
            discovery = provider.resolve(comp_name)
            if updated and discovery:
                if discovery.publish_policy == PublishPolicy.HOMEASSISTANT and self.hass_cfg.discovery.enabled:
                    self.publish_hass_config(discovery)
                if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                    self.publish_discovery(discovery)
                if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                    self.publish_hass_state(discovery)
            else:
                logger.debug("No change to republish after execution")
        logger.info("Execution ended")
    except Exception:
        logger.exception("Execution failed", component=comp_name, command=command)
def local_message(self, discovery: Discovery, command: str) -> None:
    """Simulate an incoming MQTT message for local commands"""
    # Same "source_type|name|command" form that execute_command splits.
    payload = "|".join((discovery.source_type, discovery.name, command))
    topic = self.command_topic(discovery.provider)
    self.handle_message(LocalMessage(topic=topic, payload=payload))
def on_subscribe(
    self,
    _client: mqtt.Client,
    userdata: Any,
    mid: int,
    reason_code_list: list[ReasonCode],
    properties: Properties | None = None,
) -> None:
    """Log the details of a subscription acknowledgement at debug level."""
    details = (userdata, mid, reason_code_list, properties)
    self.log.debug("on_subscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", *details)
def on_unsubscribe(
    self,
    _client: mqtt.Client,
    userdata: Any,
    mid: int,
    reason_code_list: list[ReasonCode],
    properties: Properties | None = None,
) -> None:
    """Log the details of an unsubscribe acknowledgement at debug level."""
    details = (userdata, mid, reason_code_list, properties)
    self.log.debug("on_unsubscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", *details)
def on_message(self, _client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
    """Callback for incoming MQTT messages"""  # noqa: D401
    if msg.topic not in self.providers_by_topic:
        # apparently the root non-wildcard sub sometimes brings in child topics
        self.log.debug("Unhandled message #%s on %s:%s", msg.mid, msg.topic, msg.payload)
        return
    self.handle_message(msg)
def handle_message(self, msg: mqtt.MQTTMessage | LocalMessage) -> None:
    """Schedule execution of a command message on the stored event loop.

    The command runs via ``execute_command``, with callbacks that publish
    the discovery's in-progress flag when an update starts and ends. An
    error is logged when no event loop is stored.

    Args:
        msg: Incoming or locally simulated command message.
    """

    def publish_progress(discovery: Discovery, in_progress: bool) -> None:
        # Shared by both callbacks; only the in_progress flag differs.
        if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
            self.publish_hass_state(discovery, in_progress=in_progress)
        if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
            self.publish_discovery(discovery, in_progress=in_progress)

    def update_start(discovery: Discovery) -> None:
        # No "%s" in the event: there is no positional argument to fill it,
        # and the topic is already attached as a keyword.
        self.log.debug("on_update_start", topic=msg.topic)
        publish_progress(discovery, True)

    def update_end(discovery: Discovery) -> None:
        self.log.debug("on_update_end", topic=msg.topic)
        publish_progress(discovery, False)

    # TODO: fix double publish on callback and in command exec
    if self.event_loop is None:
        self.log.error("No event loop to handle message", topic=msg.topic)
        return
    self.log.debug("Executing command topic", topic=msg.topic)
    # May be called off the event loop's thread (see use as an MQTT
    # callback), hence the thread-safe scheduling.
    asyncio.run_coroutine_threadsafe(
        self.execute_command(msg=msg, on_update_start=update_start, on_update_end=update_end), loop=self.event_loop
    )
def config_topic(self, discovery: Discovery) -> str:
    """Return the Home Assistant discovery config topic for ``discovery``."""
    entity = f"{self.node_cfg.name}_{discovery.source_type}_{discovery.name}"
    return f"{self.hass_cfg.discovery.prefix}/update/{entity}/update/config"
def reverse_config_topic(self, topic: str, source_type: str) -> Discovery | None:
    """Map a Home Assistant config topic back to a known discovery.

    Args:
        topic: Topic to interpret.
        source_type: Source type of the provider the topic should belong to.

    Returns:
        The matching discovery of the registered provider, or ``None`` when
        the topic has a different shape or names an unknown discovery.
    """
    # Configuration values are literals, not patterns: escape them so that
    # characters such as "." in a node name match only themselves.
    head = re.escape(f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{source_type}_")
    tail = re.escape("/update/config")
    match = re.fullmatch(f"{head}({MQTT_NAME}){tail}", topic)
    if match:
        discovery_name: str = match.group(1)
        if source_type in self.providers_by_type and discovery_name in self.providers_by_type[source_type].discoveries:
            return self.providers_by_type[source_type].discoveries[discovery_name]

    self.log.debug("MQTT CONFIG no match for %s", topic)
    return None
def state_topic(self, discovery: Discovery) -> str:
    """Return the state topic for ``discovery``."""
    base = f"{self.cfg.topic_root}/{self.node_cfg.name}"
    return f"{base}/{discovery.source_type}/{discovery.name}/state"
def reverse_state_topic(self, topic: str, source_type: str) -> Discovery | None:
    """Map a state topic back to a known discovery.

    Args:
        topic: Topic to interpret.
        source_type: Source type of the provider the topic should belong to.

    Returns:
        The matching discovery of the registered provider, or ``None`` when
        the topic has a different shape, the source type has no registered
        provider, or the discovery name is unknown.
    """
    # Configuration values are literals, not patterns: escape them so that
    # characters such as "." in a node name match only themselves.
    head = re.escape(f"{self.cfg.topic_root}/{self.node_cfg.name}/{source_type}/")
    tail = re.escape("/state")
    match = re.fullmatch(f"{head}({MQTT_NAME}){tail}", topic)
    if match:
        discovery_name: str = match.group(1)
        # Guard the provider lookup so an unregistered source type is a
        # "no match", not a KeyError.
        if source_type in self.providers_by_type and discovery_name in self.providers_by_type[source_type].discoveries:
            return self.providers_by_type[source_type].discoveries[discovery_name]

    self.log.debug("MQTT STATE no match for %s", topic)
    return None
def general_topic(self, discovery: Discovery) -> str:
    """Return the general topic for ``discovery``."""
    base = f"{self.cfg.topic_root}/{self.node_cfg.name}"
    return f"{base}/{discovery.source_type}/{discovery.name}"
def reverse_general_topic(self, topic: str, source_type: str) -> Discovery | None:
    """Map a general topic back to a known discovery.

    Args:
        topic: Topic to interpret.
        source_type: Source type of the provider the topic should belong to.

    Returns:
        The matching discovery of the registered provider, or ``None`` when
        the topic has a different shape, the source type has no registered
        provider, or the discovery name is unknown.
    """
    # Configuration values are literals, not patterns: escape them so that
    # characters such as "." in a node name match only themselves.
    head = re.escape(f"{self.cfg.topic_root}/{self.node_cfg.name}/{source_type}/")
    match = re.fullmatch(f"{head}({MQTT_NAME})", topic)
    if match:
        discovery_name: str = match.group(1)
        # Guard the provider lookup so an unregistered source type is a
        # "no match", not a KeyError.
        if source_type in self.providers_by_type and discovery_name in self.providers_by_type[source_type].discoveries:
            return self.providers_by_type[source_type].discoveries[discovery_name]

    self.log.debug("MQTT ATTR no match for %s", topic)
    return None
def command_topic(self, provider: ReleaseProvider) -> str:
    """Return the topic on which commands for ``provider`` are received."""
    base = f"{self.cfg.topic_root}/{self.node_cfg.name}"
    return f"{base}/{provider.source_type}"
def publish_discovery(self, discovery: Discovery, in_progress: bool = False) -> None:
    """Comprehensive, non Home Assistant specific, base publication"""
    publishable = (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT)
    if discovery.publish_policy in publishable:
        self.log.debug("Discovery publish: %s", discovery)
        payload: dict[str, Any] = discovery.as_dict()
        payload["update"]["in_progress"] = in_progress  # ty:ignore[invalid-assignment]

        # Truncate the release summary when a maximum size is configured.
        release = payload.get("release", {})
        if release.get("summary") and self.hass_cfg.release_summary_max_size:
            limit = self.hass_cfg.release_summary_max_size
            payload["release"]["summary"] = payload["release"]["summary"][:limit]

        self.publish(self.general_topic(discovery), payload)
def publish_hass_state(self, discovery: Discovery, in_progress: bool = False) -> None:
    """Publish the Home Assistant state; only for the HOMEASSISTANT policy."""
    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
        self.log.debug("HASS State update, in progress: %s, discovery: %s", in_progress, discovery)
        topic = self.state_topic(discovery)
        state = hass_format_state(
            discovery, in_progress=in_progress, release_summary_max_size=self.hass_cfg.release_summary_max_size
        )
        self.publish(topic, state)
def publish_hass_config(self, discovery: Discovery) -> None:
    """Publish the Home Assistant config; only for the HOMEASSISTANT policy."""
    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
        object_id = f"{discovery.source_type}_{self.node_cfg.name}_{discovery.name}"
        self.log.debug("HASS Config: %s", object_id)

        topic = self.config_topic(discovery)
        # The general topic is passed as attributes topic only when extra
        # attributes are enabled.
        attrs_topic = self.general_topic(discovery) if self.hass_cfg.extra_attributes else None
        config = hass_format_config(
            discovery=discovery,
            object_id=object_id,
            area=self.hass_cfg.area,
            state_topic=self.state_topic(discovery),
            attrs_topic=attrs_topic,
            command_topic=self.command_topic(discovery.provider),
            force_command_topic=self.hass_cfg.force_command_topic,
            device_creation=self.hass_cfg.device_creation,
        )
        self.publish(topic, config)
def subscribe_hass_command(self, provider: ReleaseProvider):  # noqa: ANN201
    """Register ``provider`` and subscribe to its command topic.

    Nothing is registered when the topic is already known or no client
    exists. The command topic is returned either way.
    """
    topic = self.command_topic(provider)
    if self.client is not None and topic not in self.providers_by_topic:
        self.log.info("Handler subscribing", topic=topic)
        self.providers_by_topic[topic] = provider
        self.providers_by_type[provider.source_type] = provider
        self.client.subscribe(topic)
    else:
        self.log.debug("Skipping subscription", topic=topic)
    return topic
def loop_once(self) -> None:
    """Run one iteration of the client's network loop, if a client exists."""
    client = self.client
    if not client:
        return
    client.loop()
def publish(self, topic: str, payload: dict, qos: int = 0, retain: bool = True) -> None:
    """Publish ``payload`` as JSON to ``topic`` and log the outcome.

    Args:
        topic: Destination topic.
        payload: Data serialised with ``json.dumps``.
        qos: Quality of service passed to the client.
        retain: Retain flag passed to the client.
    """
    client = self.client
    if not client:
        self.log.debug("No client to publish at %s", topic)
        return

    info: MQTTMessageInfo = client.publish(topic, payload=json.dumps(payload), qos=qos, retain=retain)
    if info.rc != MQTTErrorCode.MQTT_ERR_SUCCESS:
        self.log.warning("Problem publishing to %s, mid: %s, rc: %s", topic, info.mid, info.rc)
        return
    self.log.debug("Publish to %s, mid: %s, published: %s, rc: %s", topic, info.mid, info.is_published(), info.rc)