Coverage for src / updates2mqtt / mqtt.py: 88%
231 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-20 02:29 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-20 02:29 +0000
1import asyncio
2import json
3import time
4from collections.abc import Callable
5from dataclasses import dataclass, field
6from threading import Event
7from typing import Any
9import paho.mqtt.client as mqtt
10import paho.mqtt.subscribeoptions
11import structlog
12from paho.mqtt.client import MQTT_CLEAN_START_FIRST_ONLY, MQTTMessage
13from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode, MQTTProtocolVersion
14from paho.mqtt.properties import Properties
15from paho.mqtt.reasoncodes import ReasonCode
17from updates2mqtt.model import Discovery, ReleaseProvider
19from .config import HomeAssistantConfig, MqttConfig, NodeConfig, PublishPolicy
20from .hass_formatter import hass_format_config, hass_format_state
22log = structlog.get_logger()
25@dataclass
26class LocalMessage:
27 topic: str | None = field(default=None)
28 payload: str | None = field(default=None)
31class MqttPublisher:
def __init__(self, cfg: MqttConfig, node_cfg: NodeConfig, hass_cfg: HomeAssistantConfig) -> None:
    """Keep the configuration and set up the not-yet-connected publisher state.

    No client is created here; ``self.client`` stays None until one is assigned.

    Args:
        cfg: Broker settings; ``cfg.host`` is also bound into the logger context.
        node_cfg: Settings for this node.
        hass_cfg: Home Assistant related settings.
    """
    # Configuration, stored exactly as given
    self.cfg: MqttConfig = cfg
    self.node_cfg: NodeConfig = node_cfg
    self.hass_cfg: HomeAssistantConfig = hass_cfg

    # Runtime state, all empty until the publisher is started and providers subscribe
    self.client: mqtt.Client | None = None
    self.event_loop: asyncio.AbstractEventLoop | None = None
    self.providers_by_topic: dict[str, ReleaseProvider] = {}

    # Flag consulted by is_available(), on_connect() and clean_topics()
    self.fatal_failure = Event()
    self.log = structlog.get_logger().bind(host=cfg.host, integration="mqtt")
42 def start(self, event_loop: asyncio.AbstractEventLoop | None = None) -> None:
43 logger = self.log.bind(action="start")
44 try:
45 protocol: MQTTProtocolVersion
46 if self.cfg.protocol in ("3", "3.11"):
47 protocol = MQTTProtocolVersion.MQTTv311
48 elif self.cfg.protocol == "3.1":
49 protocol = MQTTProtocolVersion.MQTTv31
50 elif self.cfg.protocol in ("5", "5.0"):
51 protocol = MQTTProtocolVersion.MQTTv5
52 else:
53 logger.info("No valid MQTT protocol version found (%s), setting to default v3.11", self.cfg.protocol)
54 protocol = MQTTProtocolVersion.MQTTv311
55 logger.debug("MQTT protocol set to %r", protocol)
57 self.event_loop = event_loop or asyncio.get_event_loop()
58 self.client = mqtt.Client(
59 callback_api_version=CallbackAPIVersion.VERSION2,
60 client_id=f"updates2mqtt_{self.node_cfg.name}",
61 clean_session=True if protocol != MQTTProtocolVersion.MQTTv5 else None,
62 protocol=protocol,
63 )
64 self.client.username_pw_set(self.cfg.user, password=self.cfg.password)
65 rc: MQTTErrorCode = self.client.connect(
66 host=self.cfg.host,
67 port=self.cfg.port,
68 keepalive=60,
69 clean_start=MQTT_CLEAN_START_FIRST_ONLY,
70 )
71 logger.info("Client connection requested", result_code=rc)
73 self.client.on_connect = self.on_connect
74 self.client.on_disconnect = self.on_disconnect
75 self.client.on_message = self.on_message
76 self.client.on_subscribe = self.on_subscribe
77 self.client.on_unsubscribe = self.on_unsubscribe
79 self.client.loop_start()
81 logger.debug("MQTT Publisher loop started", host=self.cfg.host, port=self.cfg.port)
82 except Exception as e:
83 logger.error("Failed to connect to broker", host=self.cfg.host, port=self.cfg.port, error=str(e))
84 raise OSError(f"Connection Failure to {self.cfg.host}:{self.cfg.port} as {self.cfg.user} -- {e}") from e
def stop(self) -> None:
    """Stop the network loop, disconnect and drop the client; do nothing when there is no client."""
    client = self.client
    if not client:
        return
    client.loop_stop()
    client.disconnect()
    self.client = None
def is_available(self) -> bool:
    """Return True when a client exists and the fatal-failure flag is not set."""
    if self.client is None:
        return False
    return not self.fatal_failure.is_set()
def on_connect(
    self, _client: mqtt.Client, _userdata: Any, _flags: mqtt.ConnectFlags, rc: ReasonCode, _props: Properties | None
) -> None:
    """Handle the result of a connection attempt.

    A "Not authorized" result sets ``fatal_failure``. Any other non-zero result is only logged.
    On success every topic in ``providers_by_topic`` is subscribed again.

    Args:
        rc: Reason code of the attempt; compared against its name and against 0.
    """
    if not self.client or self.fatal_failure.is_set():
        self.log.warning("No client, check if started and authorized")
        return
    if rc.getName() == "Not authorized":
        self.fatal_failure.set()
        # Logged through the publisher's bound logger so the host/integration context is kept
        self.log.error("Invalid MQTT credentials", result_code=rc)
        return
    if rc != 0:
        self.log.warning("Connection failed to broker", result_code=rc)
        return
    self.log.debug("Connected to broker", result_code=rc)
    for topic, provider in self.providers_by_topic.items():
        self.log.debug("(Re)subscribing", topic=topic, provider=provider.source_type)
        self.client.subscribe(topic)
def on_disconnect(
    self,
    _client: mqtt.Client,
    _userdata: Any,
    _disconnect_flags: mqtt.DisconnectFlags,
    rc: ReasonCode,
    _props: Properties | None,
) -> None:
    """Log a disconnect: at debug level when ``rc`` equals 0, as a warning otherwise."""
    if rc == 0:
        self.log.debug("Disconnected from broker", result_code=rc)
        return
    self.log.warning("Disconnect failure from broker", result_code=rc)
async def clean_topics(
    self, provider: ReleaseProvider, last_scan_session: str | None, wait_time: int = 5, force: bool = False
) -> None:
    """Clear stale retained messages published for one provider.

    A temporary second client subscribes to the discovery and provider topic trees. For each
    retained message under this node's prefixes it reads ``source_session`` from the JSON payload
    and clears the topic (empty retained publish) when that session differs from
    ``last_scan_session``. The temporary client is always disconnected before returning.

    Args:
        provider: Provider whose ``source_type`` selects the topics to inspect.
        last_scan_session: Session to compare against; when None no message counts as stale.
        wait_time: Seconds without a matching message after which the cycle ends.
        force: Also clear matching messages that carry no ``source_session``.
    """
    logger = self.log.bind(action="clean")
    if self.fatal_failure.is_set():
        return
    logger.info("Starting clean cycle")
    cleaner = mqtt.Client(
        # VERSION1 is deprecated in paho-mqtt 2.x; on_message, the only callback set here,
        # has the same signature in both versions
        callback_api_version=CallbackAPIVersion.VERSION2,
        client_id=f"updates2mqtt_clean_{self.node_cfg.name}",
        clean_session=True,
    )
    results = {"cleaned": 0, "handled": 0, "discovered": 0, "last_timestamp": time.time()}
    cleaner.username_pw_set(self.cfg.user, password=self.cfg.password)
    cleaner.connect(host=self.cfg.host, port=self.cfg.port, keepalive=60)
    # A tuple lets str.startswith test both prefixes in one call
    prefixes = (
        f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{provider.source_type}_",
        f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/",
    )

    def cleanup(_client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
        """Inspect one received message and clear it if it is stale or, with force, untrackable."""
        if not (msg.retain and msg.topic.startswith(prefixes)):
            logger.debug("Skipping clean of %s", msg.topic)
            return
        session = None
        results["discovered"] += 1
        try:
            payload = self.safe_json_decode(msg.payload)
            session = payload.get("source_session")
        except Exception as e:
            logger.warning(
                "Unable to handle payload for %s: %s",
                msg.topic,
                e,
                exc_info=1,
            )
        results["handled"] += 1
        # Every matching message restarts the idle timer used by the wait loop
        results["last_timestamp"] = time.time()
        if session is not None and last_scan_session is not None and session != last_scan_session:
            logger.debug("Removing stale msg", topic=msg.topic, session=session)
            cleaner.publish(msg.topic, "", retain=True)
            results["cleaned"] += 1
        elif session is None and force:
            logger.debug("Removing untrackable msg", topic=msg.topic)
            cleaner.publish(msg.topic, "", retain=True)
            results["cleaned"] += 1
        else:
            logger.debug(
                "Retaining topic with current session: %s",
                msg.topic,
            )

    try:
        cleaner.on_message = cleanup
        options = paho.mqtt.subscribeoptions.SubscribeOptions(noLocal=True)
        cleaner.subscribe(f"{self.hass_cfg.discovery.prefix}/update/#", options=options)
        cleaner.subscribe(f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/#", options=options)

        # NOTE(review): loop() blocks for up to 0.5s per pass and nothing is awaited here, so the
        # event loop running this coroutine appears to be held for the whole cycle - confirm intended
        while time.time() - results["last_timestamp"] <= wait_time:
            cleaner.loop(0.5)
    finally:
        # Release the temporary connection even if the cycle fails part-way
        cleaner.disconnect()

    logger.info(
        f"Clean completed, discovered:{results['discovered']}, handled:{results['handled']}, cleaned:{results['cleaned']}"
    )
def safe_json_decode(self, jsonish: str | bytes | None) -> dict:
    """Decode a JSON object from ``jsonish`` without raising.

    If the first parse fails, a second parse is tried with the first and last character removed
    (presumably for payloads wrapped in an extra pair of quotes - confirm against producers).

    Args:
        jsonish: Raw JSON text or bytes; None or empty input is treated as "no data".

    Returns:
        The decoded object when it is a dict, otherwise an empty dict. This includes failed
        parses and valid JSON whose top level is not an object.
    """
    # Empty payloads cannot be JSON; skip the two failing parses and their tracebacks
    if not jsonish:
        return {}
    try:
        decoded = json.loads(jsonish)
    except Exception:
        log.exception("JSON decode fail (%s)", jsonish)
        try:
            decoded = json.loads(jsonish[1:-1])
        except Exception:
            log.exception("JSON decode fail (%s)", jsonish[1:-1])
            return {}
    # Honour the declared return type: lists, strings and numbers are valid JSON but not usable here
    return decoded if isinstance(decoded, dict) else {}
async def execute_command(
    self, msg: MQTTMessage | LocalMessage, on_update_start: Callable, on_update_end: Callable
) -> None:
    """Run the command carried by ``msg`` on the provider registered for its topic.

    The payload is read as ``source_type|component|command``. The command is passed on only when
    the topic has a registered provider, the source type matches that provider, the command is
    ``install`` and a component is named. When the provider reports an update and can resolve the
    component, the discovery is republished according to its publish policy.

    Failures are logged and never propagate to the caller.

    Args:
        msg: Broker or locally built message; ``payload`` may be ``bytes`` or ``str``.
        on_update_start: Callback handed through to ``provider.command``.
        on_update_end: Callback handed through to ``provider.command``.
    """
    logger = self.log.bind(topic=msg.topic, payload=msg.payload)
    comp_name: str | None = None
    command: str | None = None
    try:
        logger.info("Execution starting")
        source_type: str | None = None

        payload: str | None = None
        if isinstance(msg.payload, bytes):
            payload = msg.payload.decode("utf-8")
        elif isinstance(msg.payload, str):
            payload = msg.payload
        # Only a payload with exactly three fields is unpacked. Any other shape leaves the fields
        # unset, so it is rejected below with a warning instead of failing on tuple unpacking.
        parts = payload.split("|") if payload else []
        if len(parts) == 3:
            source_type, comp_name, command = parts

        provider: ReleaseProvider | None = self.providers_by_topic.get(msg.topic) if msg.topic else None
        if not provider:
            logger.warning("Unexpected provider type %s", msg.topic)
        elif provider.source_type != source_type:
            logger.warning("Unexpected source type %s", source_type)
        elif command != "install" or not comp_name:
            logger.warning("Invalid payload in command message: %s", msg.payload)
        else:
            logger.info(
                "Passing %s command to %s scanner for %s",
                command,
                source_type,
                comp_name,
            )
            updated = provider.command(comp_name, command, on_update_start, on_update_end)
            discovery = provider.resolve(comp_name)
            if updated and discovery:
                if discovery.publish_policy == PublishPolicy.HOMEASSISTANT and self.hass_cfg.discovery.enabled:
                    self.publish_hass_config(discovery)
                if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                    self.publish_discovery(discovery)
                if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                    self.publish_hass_state(discovery)
            else:
                logger.debug("No change to republish after execution")
        logger.info("Execution ended")
    except Exception:
        logger.exception("Execution failed", component=comp_name, command=command)
def local_message(self, discovery: Discovery, command: str) -> None:
    """Simulate an incoming MQTT message for local commands"""
    # Same topic and "source_type|name|command" payload layout that execute_command parses
    topic = self.command_topic(discovery.provider)
    fields = (discovery.source_type, discovery.name, command)
    self.handle_message(LocalMessage(topic=topic, payload="|".join(fields)))
def on_subscribe(
    self,
    _client: mqtt.Client,
    userdata: Any,
    mid: int,
    reason_code_list: list[ReasonCode],
    properties: Properties | None = None,
) -> None:
    """Log the details of a subscribe callback at debug level; nothing else is done."""
    template = "on_subscribe, userdata=%s, mid=%s, reasons=%s, properties=%s"
    self.log.debug(template, userdata, mid, reason_code_list, properties)
def on_unsubscribe(
    self,
    _client: mqtt.Client,
    userdata: Any,
    mid: int,
    reason_code_list: list[ReasonCode],
    properties: Properties | None = None,
) -> None:
    """Log the details of an unsubscribe callback at debug level; nothing else is done."""
    template = "on_unsubscribe, userdata=%s, mid=%s, reasons=%s, properties=%s"
    self.log.debug(template, userdata, mid, reason_code_list, properties)
def on_message(self, _client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
    """Callback for incoming MQTT messages"""  # noqa: D401
    if msg.topic not in self.providers_by_topic:
        # apparently the root non-wildcard sub sometimes brings in child topics
        self.log.debug("Unhandled message #%s on %s:%s", msg.mid, msg.topic, msg.payload)
        return
    self.handle_message(msg)
def handle_message(self, msg: mqtt.MQTTMessage | LocalMessage) -> None:
    """Schedule ``execute_command`` for ``msg`` on the stored event loop.

    The two callbacks handed to ``execute_command`` publish the Home Assistant state with
    ``in_progress`` True and False respectively. Without an event loop the message is only logged.
    """
    loop = self.event_loop
    if loop is None:
        self.log.error("No event loop to handle message", topic=msg.topic)
        return

    def progress_notifier(in_progress: bool) -> Callable:
        """Build a callback that publishes the state with the given in-progress flag."""

        def notify(discovery: Discovery) -> None:
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_hass_state(discovery, in_progress=in_progress)

        return notify

    coroutine = self.execute_command(msg, progress_notifier(True), progress_notifier(False))
    # Thread-safe submission, presumably because this is also reached from the MQTT client's
    # callback thread - confirm against callers
    asyncio.run_coroutine_threadsafe(coroutine, loop)
def config_topic(self, discovery: Discovery) -> str:
    """Return the topic under the discovery prefix where this discovery's config is published."""
    prefix = self.hass_cfg.discovery.prefix
    node = self.node_cfg.name
    return f"{prefix}/update/{node}_{discovery.source_type}_{discovery.name}/update/config"
def state_topic(self, discovery: Discovery) -> str:
    """Return the ``.../state`` topic for this discovery under the configured topic root."""
    root = self.cfg.topic_root
    node = self.node_cfg.name
    return f"{root}/{node}/{discovery.source_type}/{discovery.name}/state"
def general_topic(self, discovery: Discovery) -> str:
    """Return the base topic for this discovery: topic root, node name, source type and name."""
    root = self.cfg.topic_root
    node = self.node_cfg.name
    return f"{root}/{node}/{discovery.source_type}/{discovery.name}"
def command_topic(self, provider: ReleaseProvider) -> str:
    """Return the per-provider command topic: topic root, node name and the provider's source type."""
    root = self.cfg.topic_root
    node = self.node_cfg.name
    return f"{root}/{node}/{provider.source_type}"
def publish_discovery(self, discovery: Discovery, in_progress: bool = False) -> None:
    """Comprehensive, non Home Assistant specific, base publication"""
    publishable = (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT)
    if discovery.publish_policy in publishable:
        self.log.debug("Discovery publish: %s", discovery)
        body: dict[str, Any] = discovery.as_dict()
        # The in-progress flag is written into the "update" section of the payload
        body["update"]["in_progress"] = in_progress  # ty:ignore[invalid-assignment]
        self.publish(self.general_topic(discovery), body)
def publish_hass_state(self, discovery: Discovery, in_progress: bool = False) -> None:
    """Publish the Home Assistant formatted state; only for discoveries with the HOMEASSISTANT policy."""
    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
        self.log.debug("HASS State update, in progress: %s, discovery: %s", in_progress, discovery)
        topic = self.state_topic(discovery)
        state = hass_format_state(
            discovery,
            discovery.session,
            in_progress=in_progress,
        )
        self.publish(topic, state)
def publish_hass_config(self, discovery: Discovery) -> None:
    """Publish the Home Assistant config payload; only for discoveries with the HOMEASSISTANT policy."""
    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
        object_id = f"{discovery.source_type}_{self.node_cfg.name}_{discovery.name}"
        self.log.debug("HASS Config: %s", object_id)

        topic = self.config_topic(discovery)
        hass = self.hass_cfg
        # The attributes topic is only supplied when extra attributes are enabled
        config = hass_format_config(
            discovery=discovery,
            object_id=object_id,
            area=hass.area,
            state_topic=self.state_topic(discovery),
            attrs_topic=self.general_topic(discovery) if hass.extra_attributes else None,
            command_topic=self.command_topic(discovery.provider),
            force_command_topic=hass.force_command_topic,
            device_creation=hass.device_creation,
        )
        self.publish(topic, config)
def subscribe_hass_command(self, provider: ReleaseProvider) -> str:
    """Register ``provider`` for its command topic and subscribe to that topic.

    Nothing is registered or subscribed when the topic is already registered or when there is
    no client.

    Args:
        provider: Provider that will receive commands arriving on the topic.

    Returns:
        The provider's command topic, whether or not a subscription was made.
    """
    topic = self.command_topic(provider)
    if topic in self.providers_by_topic or self.client is None:
        self.log.debug("Skipping subscription", topic=topic)
    else:
        self.log.info("Handler subscribing", topic=topic)
        self.providers_by_topic[topic] = provider
        self.client.subscribe(topic)
    return topic
def loop_once(self) -> None:
    """Run a single pass of the client's network loop; do nothing when there is no client."""
    client = self.client
    if not client:
        return
    client.loop()
def publish(self, topic: str, payload: dict, qos: int = 0, retain: bool = True) -> None:
    """Publish ``payload`` as JSON text on ``topic``; silently does nothing when there is no client.

    Args:
        topic: Topic to publish on.
        payload: Data serialised with ``json.dumps``.
        qos: Quality-of-service level passed to the client.
        retain: Retain flag passed to the client.
    """
    client = self.client
    if not client:
        return
    encoded = json.dumps(payload)
    client.publish(topic, payload=encoded, qos=qos, retain=retain)