Coverage for src / updates2mqtt / mqtt.py: 79%
283 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-06-02 10:03 +0000
1import asyncio
2import json
3import re
4import time
5from collections.abc import Callable
6from dataclasses import dataclass, field
7from threading import Event
8from typing import Any
10import paho.mqtt.client as mqtt
11import paho.mqtt.subscribeoptions
12import structlog
13from paho.mqtt.client import MQTT_CLEAN_START_FIRST_ONLY, MQTTMessage, MQTTMessageInfo
14from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode, MQTTProtocolVersion
15from paho.mqtt.properties import Properties
16from paho.mqtt.reasoncodes import ReasonCode
18from updates2mqtt.model import Discovery, ReleaseProvider
20from .config import HomeAssistantConfig, MqttConfig, NodeConfig, PublishPolicy
21from .hass_formatter import hass_format_config, hass_format_state
23log = structlog.get_logger()
25MQTT_NAME = r"[A-Za-z0-9_\-\.]+"
28@dataclass
29class LocalMessage:
30 topic: str | None = field(default=None)
31 payload: str | None = field(default=None)
class MqttPublisher:
    """Bridge between release providers and an MQTT broker.

    Publishes discovery and state payloads (optionally in Home Assistant MQTT discovery format),
    subscribes to one command topic per provider, and runs received commands as coroutines on an
    asyncio event loop. paho's network loop runs on its own thread (``loop_start``), so command
    handling is handed over with ``asyncio.run_coroutine_threadsafe``.
    """

    # CONNACK reason names meaning the broker rejected our credentials. paho reports the MQTT v3.x
    # CONNACK return codes 4 and 5 under these same names, so this covers v3.x and v5 connections.
    _AUTH_FAILURE_REASONS = frozenset({"Not authorized", "Bad user name or password"})

    def __init__(self, cfg: MqttConfig, node_cfg: NodeConfig, hass_cfg: HomeAssistantConfig) -> None:
        """Store configuration; nothing touches the network until :meth:`start`.

        Args:
            cfg: broker connection settings (host, port, credentials, protocol, topic root).
            node_cfg: settings for this node; its ``name`` is embedded in client ids and topics.
            hass_cfg: Home Assistant discovery and publishing settings.
        """
        self.cfg: MqttConfig = cfg
        self.node_cfg: NodeConfig = node_cfg
        self.hass_cfg: HomeAssistantConfig = hass_cfg
        # command topic -> provider, and provider source_type -> provider
        self.providers_by_topic: dict[str, ReleaseProvider] = {}
        self.providers_by_type: dict[str, ReleaseProvider] = {}
        self.event_loop: asyncio.AbstractEventLoop | None = None
        self.client: mqtt.Client | None = None
        # Set once the broker rejects our credentials; suppresses further connection handling.
        self.fatal_failure = Event()
        self.log = structlog.get_logger().bind(host=cfg.host, integration="mqtt")

    def start(self, event_loop: asyncio.AbstractEventLoop | None = None) -> None:
        """Create the MQTT client, request a broker connection and start paho's network thread.

        Args:
            event_loop: loop on which incoming commands are executed; defaults to
                ``asyncio.get_event_loop()``.

        Raises:
            OSError: if creating the client or requesting the connection failed. ``self.client``
                is reset to ``None`` so :meth:`is_available` reports the failure.
        """
        logger = self.log.bind(action="start")
        try:
            protocol: MQTTProtocolVersion
            if self.cfg.protocol in ("3", "3.11", "3.1.1"):
                protocol = MQTTProtocolVersion.MQTTv311
            elif self.cfg.protocol == "3.1":
                protocol = MQTTProtocolVersion.MQTTv31
            elif self.cfg.protocol in ("5", "5.0"):
                protocol = MQTTProtocolVersion.MQTTv5
            else:
                logger.info("No valid MQTT protocol version found (%s), setting to default v3.11", self.cfg.protocol)
                protocol = MQTTProtocolVersion.MQTTv311
            logger.debug("MQTT protocol set to %r", protocol)

            self.event_loop = event_loop or asyncio.get_event_loop()
            client = mqtt.Client(
                callback_api_version=CallbackAPIVersion.VERSION2,
                client_id=f"updates2mqtt_{self.node_cfg.name}",
                # clean_session only exists for MQTT v3.x; v5 uses clean_start on connect instead
                clean_session=True if protocol != MQTTProtocolVersion.MQTTv5 else None,
                protocol=protocol,
            )
            client.username_pw_set(self.cfg.user, password=self.cfg.password)
            # Register callbacks before connecting so no event can arrive without a handler.
            client.on_connect = self.on_connect
            client.on_disconnect = self.on_disconnect
            client.on_message = self.on_message
            client.on_subscribe = self.on_subscribe
            client.on_unsubscribe = self.on_unsubscribe
            # on_connect checks self.client, so publish it before the network loop can fire callbacks.
            self.client = client

            rc: MQTTErrorCode = client.connect(
                host=self.cfg.host,
                port=self.cfg.port,
                keepalive=60,
                clean_start=MQTT_CLEAN_START_FIRST_ONLY,
            )
            logger.info("Client connection requested", result_code=rc)

            client.loop_start()

            logger.debug("MQTT Publisher loop started", host=self.cfg.host, port=self.cfg.port)
        except Exception as e:
            # Don't leave a half-initialised client behind that would look available.
            self.client = None
            logger.error("Failed to connect to broker", host=self.cfg.host, port=self.cfg.port, error=str(e))
            raise OSError(f"Connection Failure to {self.cfg.host}:{self.cfg.port} as {self.cfg.user} -- {e}") from e

    def stop(self) -> None:
        """Disconnect from the broker and stop the network thread; a no-op when not started."""
        if self.client:
            # Disconnect before stopping the loop so the network thread can flush the DISCONNECT
            # packet and run on_disconnect (the order recommended by paho).
            self.client.disconnect()
            self.client.loop_stop()
            self.client = None

    def is_available(self) -> bool:
        """Return True when a client exists and the broker has not rejected our credentials."""
        return self.client is not None and not self.fatal_failure.is_set()

    def on_connect(
        self, _client: mqtt.Client, _userdata: Any, _flags: mqtt.ConnectFlags, rc: ReasonCode, _props: Properties | None
    ) -> None:
        """paho callback for the broker's CONNACK.

        A credential rejection sets ``fatal_failure``. On a successful (re)connection, every
        registered command topic is subscribed again.
        """
        if not self.client or self.fatal_failure.is_set():
            self.log.warning("No client, check if started and authorized")
            return
        if rc.getName() in self._AUTH_FAILURE_REASONS:
            self.fatal_failure.set()
            self.log.error("Invalid MQTT credentials", result_code=rc)
            return
        if rc != 0:
            self.log.warning("Connection failed to broker", result_code=rc)
            return
        self.log.debug("Connected to broker", result_code=rc)
        # This runs on paho's network thread while subscribe_hass_command may add topics from
        # another thread, so iterate over a snapshot.
        for topic, provider in list(self.providers_by_topic.items()):
            self.log.debug("(Re)subscribing", topic=topic, provider=provider.source_type)
            self.client.subscribe(topic)

    def on_disconnect(
        self,
        _client: mqtt.Client,
        _userdata: Any,
        _disconnect_flags: mqtt.DisconnectFlags,
        rc: ReasonCode,
        _props: Properties | None,
    ) -> None:
        """paho callback: log a clean disconnect at debug level and any other reason as a warning."""
        if rc == 0:
            self.log.debug("Disconnected from broker", result_code=rc)
        else:
            self.log.warning("Disconnect failure from broker", result_code=rc)

    async def clean_topics(
        self, provider: ReleaseProvider, wait_time: int = 5, max_time: int = 120, initial: bool = False
    ) -> None:
        """Remove stale retained topics belonging to ``provider`` from the broker.

        A short-lived, separate client subscribes to the Home Assistant ``update`` discovery prefix
        and to this node's topics for the provider's source type. It then inspects each message:

        * unless ``initial``, topics that no longer map to a known discovery are cleared;
        * when ``initial``, payloads whose JSON has a truthy ``in_progress`` are cleared;
        * payloads that cannot be inspected as a JSON object are cleared.

        Clearing means publishing an empty retained payload. The scan stops once no relevant
        message has arrived for ``wait_time`` seconds, or after ``max_time`` seconds in total.
        The cleaner client is always disconnected afterwards. Errors are logged, never raised.
        """
        logger = self.log.bind(action="clean")

        if self.fatal_failure.is_set():
            return
        try:
            logger.info("Starting clean cycle, wait time: %s, max time: %s, initial: %s", wait_time, max_time, initial)
            cutoff_time: float = time.time() + max_time
            cleaner = mqtt.Client(
                callback_api_version=CallbackAPIVersion.VERSION1,
                client_id=f"updates2mqtt_clean_{self.node_cfg.name}",
                clean_session=True,
            )
            results = {"cleaned": 0, "matched": 0, "discovered": 0, "last_timestamp": time.time()}
            cleaner.username_pw_set(self.cfg.user, password=self.cfg.password)
            cleaner.connect(host=self.cfg.host, port=self.cfg.port, keepalive=60)
            try:
                config_prefix = f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{provider.source_type}_"
                node_prefix = f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/"

                def clear(topic: str) -> None:
                    """Delete the retained message at ``topic`` and count it."""
                    cleaner.publish(topic, "", retain=True)
                    results["cleaned"] += 1

                def cleanup(_client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
                    if not msg.payload:
                        # A zero-length payload is a deletion, not a retained entry. This includes
                        # the broker echoing our own deletions back: this client is created without a
                        # protocol argument (paho's v3.1.1 default), so noLocal is not honoured.
                        # Clearing it again would echo again and keep the scan alive until max_time.
                        return
                    discovery: Discovery | None = None
                    if msg.topic.startswith(config_prefix):
                        discovery = self.reverse_config_topic(msg.topic, provider.source_type)
                    elif msg.topic.startswith(node_prefix) and msg.topic.endswith("/state"):
                        discovery = self.reverse_state_topic(msg.topic, provider.source_type)
                    elif msg.topic.startswith(node_prefix):
                        discovery = self.reverse_general_topic(msg.topic, provider.source_type)
                    else:
                        logger.debug("Ignoring other topic ", topic=msg.topic)
                        return
                    results["discovered"] += 1
                    if not initial and discovery is None:
                        logger.debug("Removing unknown discovery", topic=msg.topic)
                        clear(msg.topic)
                    elif discovery is not None:
                        results["matched"] += 1

                    try:
                        payload = json.loads(msg.payload)
                        if payload.get("in_progress") and initial:
                            clear(msg.topic)
                    except Exception as e:
                        logger.warning("Invalid payload at %s: %s", msg.topic, e)
                        clear(msg.topic)

                    results["last_timestamp"] = time.time()

                cleaner.on_message = cleanup
                # paho only applies SubscribeOptions (noLocal) on MQTT v5 connections.
                options = paho.mqtt.subscribeoptions.SubscribeOptions(noLocal=True)
                cleaner.subscribe(f"{self.hass_cfg.discovery.prefix}/update/#", options=options)
                cleaner.subscribe(f"{node_prefix}#", options=options)

                # NOTE(review): this synchronous loop blocks the event loop for up to max_time seconds.
                # It also prevents clean cycles from overlapping; overlapping cycles would reuse the
                # same client id and presumably kick each other off the broker.
                while time.time() - results["last_timestamp"] <= wait_time and time.time() <= cutoff_time:
                    cleaner.loop(0.5)
            finally:
                try:
                    cleaner.disconnect()
                except Exception as e:  # best effort: never mask the outcome of the scan itself
                    logger.debug("Failed to disconnect cleaner client: %s", e)

            logger.info(
                f"Cleaned - discovered:{results['discovered']}, matched:{results['matched']}, cleaned:{results['cleaned']}"
            )
        except Exception as e:
            logger.exception("Cleaning topics of stale entries failed: %s", e)

    def safe_json_decode(self, jsonish: str | bytes | None) -> dict:
        """Decode JSON on a best-effort basis, returning ``{}`` when decoding is impossible.

        If the text does not parse, one retry is made with its first and last characters removed
        (presumably to unwrap a document enclosed in an extra pair of quotes).
        """
        if jsonish is None:
            return {}
        try:
            return json.loads(jsonish)
        except Exception:
            self.log.exception("JSON decode fail (%s)", jsonish)
        try:
            return json.loads(jsonish[1:-1])
        except Exception:
            self.log.exception("JSON decode fail (%s)", jsonish[1:-1])
        return {}

    async def execute_command(
        self, msg: MQTTMessage | LocalMessage, on_update_start: Callable, on_update_end: Callable
    ) -> None:
        """Run a ``source_type|component|command`` message against the provider for its topic.

        Only the ``install`` command is accepted. After a successful update, the resolved discovery
        is republished according to its publish policy: Home Assistant config (if HA discovery is
        enabled), the generic discovery payload, and the HA state. Failures are logged, never raised.
        """
        # TODO: defer handling of commands where repository is throttled
        logger = self.log.bind(topic=msg.topic, payload=msg.payload)
        comp_name: str | None = None
        command: str | None = None
        try:
            logger.info("Execution starting")
            source_type: str | None = None

            payload: str | None = None
            if isinstance(msg.payload, bytes):
                payload = msg.payload.decode("utf-8")
            elif isinstance(msg.payload, str):
                payload = msg.payload
            if payload and "|" in payload:
                parts = payload.split("|")
                if len(parts) == 3:
                    source_type, comp_name, command = parts
                    logger.debug("Executing %s:%s:%s", source_type, comp_name, command)
                else:
                    # Leave the fields unset so the checks below reject the message with a warning
                    # rather than failing on tuple unpacking.
                    logger.warning("Expected 3 '|'-separated fields in command message, got %s", len(parts))

            provider: ReleaseProvider | None = self.providers_by_topic.get(msg.topic) if msg.topic else None
            if not provider:
                logger.warning("Unexpected provider type %s", msg.topic)
            elif provider.source_type != source_type:
                logger.warning("Unexpected source type %s", source_type)
            elif command != "install" or not comp_name:
                logger.warning("Invalid payload in command message: %s", msg.payload)
            else:
                logger.info(
                    "Passing %s command to %s scanner for %s",
                    command,
                    source_type,
                    comp_name,
                )
                updated: bool = provider.command(comp_name, command, on_update_start, on_update_end)
                discovery = provider.resolve(comp_name)
                if updated and discovery:
                    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT and self.hass_cfg.discovery.enabled:
                        self.publish_hass_config(discovery)
                    if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                        self.publish_discovery(discovery)
                    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                        self.publish_hass_state(discovery)
                else:
                    logger.debug("No change to republish after execution")
            logger.info("Execution ended")
        except Exception:
            logger.exception("Execution failed", component=comp_name, command=command)

    def local_message(self, discovery: Discovery, command: str) -> None:
        """Simulate an incoming MQTT message for local commands"""
        msg = LocalMessage(
            topic=self.command_topic(discovery.provider), payload="|".join([discovery.source_type, discovery.name, command])
        )
        self.handle_message(msg)

    def on_subscribe(
        self,
        _client: mqtt.Client,
        userdata: Any,
        mid: int,
        reason_code_list: list[ReasonCode],
        properties: Properties | None = None,
    ) -> None:
        """paho callback: log the SUBACK details at debug level."""
        self.log.debug(
            "on_subscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", userdata, mid, reason_code_list, properties
        )

    def on_unsubscribe(
        self,
        _client: mqtt.Client,
        userdata: Any,
        mid: int,
        reason_code_list: list[ReasonCode],
        properties: Properties | None = None,
    ) -> None:
        """paho callback: log the UNSUBACK details at debug level."""
        self.log.debug(
            "on_unsubscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", userdata, mid, reason_code_list, properties
        )

    def on_message(self, _client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
        """Callback for incoming MQTT messages"""  # noqa: D401
        if msg.topic in self.providers_by_topic:
            self.handle_message(msg)
        else:
            # apparently the root non-wildcard sub sometimes brings in child topics
            self.log.debug("Unhandled message #%s on %s:%s", msg.mid, msg.topic, msg.payload)

    def handle_message(self, msg: mqtt.MQTTMessage | LocalMessage) -> None:
        """Schedule :meth:`execute_command` for ``msg`` on ``self.event_loop``.

        Uses ``asyncio.run_coroutine_threadsafe``, so it may be called from paho's network thread.
        The start and end callbacks handed to the provider republish HA state and the discovery
        payload with ``in_progress`` set to True and False respectively.
        """

        def update_start(discovery: Discovery) -> None:
            self.log.debug("on_update_start: %s", discovery.name, topic=msg.topic)
            if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                self.publish_hass_state(discovery, in_progress=True)
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_discovery(discovery, in_progress=True)

        def update_end(discovery: Discovery) -> None:
            self.log.debug("on_update_end: %s", discovery.name, topic=msg.topic)
            if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                self.publish_hass_state(discovery, in_progress=False)
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_discovery(discovery, in_progress=False)

        # TODO: fix double publish on callback and in command exec
        if self.event_loop is not None:
            self.log.debug("Executing command topic", topic=msg.topic)
            asyncio.run_coroutine_threadsafe(
                self.execute_command(msg=msg, on_update_start=update_start, on_update_end=update_end), loop=self.event_loop
            )
        else:
            self.log.error("No event loop to handle message", topic=msg.topic)

    def _lookup_discovery(self, source_type: str, discovery_name: str) -> Discovery | None:
        """Return the known discovery ``discovery_name`` of the provider registered for ``source_type``."""
        provider = self.providers_by_type.get(source_type)
        if provider is not None and discovery_name in provider.discoveries:
            return provider.discoveries[discovery_name]
        return None

    def config_topic(self, discovery: Discovery) -> str:
        """Home Assistant discovery config topic for ``discovery``."""
        prefix = self.hass_cfg.discovery.prefix
        return f"{prefix}/update/{self.node_cfg.name}_{discovery.source_type}_{discovery.name}/update/config"

    def reverse_config_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a config topic (see :meth:`config_topic`) back to a known discovery, else ``None``."""
        # Configured names are escaped so regex metacharacters in them match literally.
        pattern = (
            re.escape(f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{source_type}_")
            + f"({MQTT_NAME})"
            + re.escape("/update/config")
        )
        match = re.fullmatch(pattern, topic)
        if match:
            discovery = self._lookup_discovery(source_type, match.group(1))
            if discovery is not None:
                return discovery
        self.log.debug("MQTT CONFIG no match for %s", topic)
        return None

    def state_topic(self, discovery: Discovery) -> str:
        """State topic for ``discovery``."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{discovery.source_type}/{discovery.name}/state"

    def reverse_state_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a state topic (see :meth:`state_topic`) back to a known discovery, else ``None``."""
        pattern = (
            re.escape(f"{self.cfg.topic_root}/{self.node_cfg.name}/{source_type}/")
            + f"({MQTT_NAME})"
            + re.escape("/state")
        )
        match = re.fullmatch(pattern, topic)
        if match:
            discovery = self._lookup_discovery(source_type, match.group(1))
            if discovery is not None:
                return discovery
        self.log.debug("MQTT STATE no match for %s", topic)
        return None

    def general_topic(self, discovery: Discovery) -> str:
        """General (attributes) topic for ``discovery``."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{discovery.source_type}/{discovery.name}"

    def reverse_general_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a general topic (see :meth:`general_topic`) back to a known discovery, else ``None``."""
        pattern = re.escape(f"{self.cfg.topic_root}/{self.node_cfg.name}/{source_type}/") + f"({MQTT_NAME})"
        match = re.fullmatch(pattern, topic)
        if match:
            discovery = self._lookup_discovery(source_type, match.group(1))
            if discovery is not None:
                return discovery
        self.log.debug("MQTT ATTR no match for %s", topic)
        return None

    def command_topic(self, provider: ReleaseProvider) -> str:
        """Topic on which commands for ``provider`` are received."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}"

    def publish_discovery(self, discovery: Discovery, in_progress: bool = False) -> None:
        """Comprehensive, non Home Assistant specific, base publication.

        Skipped unless the publish policy is HOMEASSISTANT or MQTT. The release summary is cut to
        ``hass_cfg.release_summary_max_size`` when that limit is set.
        """
        if discovery.publish_policy not in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
            return
        self.log.debug("Discovery publish: %s", discovery)
        payload: dict[str, Any] = discovery.as_dict()
        payload["update"]["in_progress"] = in_progress  # ty:ignore[invalid-assignment]
        # Guard against a missing or None "release" entry before reading its summary.
        release = payload.get("release") or {}
        max_size = self.hass_cfg.release_summary_max_size
        if release.get("summary") and max_size:
            release["summary"] = release["summary"][:max_size]
        self.publish(self.general_topic(discovery), payload)

    def publish_hass_state(self, discovery: Discovery, in_progress: bool = False) -> None:
        """Publish Home Assistant state for ``discovery``; only for the HOMEASSISTANT policy."""
        if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
            return
        self.log.debug("HASS State update, in progress: %s, discovery: %s", in_progress, discovery)
        self.publish(
            self.state_topic(discovery),
            hass_format_state(
                discovery, in_progress=in_progress, release_summary_max_size=self.hass_cfg.release_summary_max_size
            ),
        )

    def publish_hass_config(self, discovery: Discovery) -> None:
        """Publish the Home Assistant discovery config for ``discovery``; only for the HOMEASSISTANT policy."""
        if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
            return
        object_id = f"{discovery.source_type}_{self.node_cfg.name}_{discovery.name}"
        self.log.debug("HASS Config: %s", object_id)

        self.publish(
            self.config_topic(discovery),
            hass_format_config(
                discovery=discovery,
                object_id=object_id,
                area=self.hass_cfg.area,
                state_topic=self.state_topic(discovery),
                attrs_topic=self.general_topic(discovery) if self.hass_cfg.extra_attributes else None,
                command_topic=self.command_topic(discovery.provider),
                force_command_topic=self.hass_cfg.force_command_topic,
                device_creation=self.hass_cfg.device_creation,
            ),
        )

    def subscribe_hass_command(self, provider: ReleaseProvider) -> str:
        """Register ``provider`` and subscribe to its command topic, returning that topic.

        Nothing is registered when the topic is already known or there is no client.
        """
        topic = self.command_topic(provider)
        if topic in self.providers_by_topic or self.client is None:
            self.log.debug("Skipping subscription", topic=topic)
        else:
            self.log.info("Handler subscribing", topic=topic)
            self.providers_by_topic[topic] = provider
            self.providers_by_type[provider.source_type] = provider
            self.client.subscribe(topic)
        return topic

    def loop_once(self) -> None:
        """Run a single iteration of the paho network loop, if a client exists."""
        if self.client:
            self.client.loop()

    def publish(self, topic: str, payload: dict, qos: int = 0, retain: bool = True) -> None:
        """JSON-encode ``payload`` and publish it (retained by default); logs instead of raising."""
        if self.client:
            info: MQTTMessageInfo = self.client.publish(topic, payload=json.dumps(payload), qos=qos, retain=retain)
            if info.rc == MQTTErrorCode.MQTT_ERR_SUCCESS:
                self.log.debug("Publish to %s, mid: %s, published: %s, rc: %s", topic, info.mid, info.is_published(), info.rc)
            else:
                self.log.warning("Problem publishing to %s, mid: %s, rc: %s", topic, info.mid, info.rc)
        else:
            self.log.debug("No client to publish at %s", topic)