Coverage for src / updates2mqtt / mqtt.py: 79%

283 statements  

Generated by coverage.py v7.14.0 at 2026-06-02 10:03 +0000

1import asyncio 

2import json 

3import re 

4import time 

5from collections.abc import Callable 

6from dataclasses import dataclass, field 

7from threading import Event 

8from typing import Any 

9 

10import paho.mqtt.client as mqtt 

11import paho.mqtt.subscribeoptions 

12import structlog 

13from paho.mqtt.client import MQTT_CLEAN_START_FIRST_ONLY, MQTTMessage, MQTTMessageInfo 

14from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode, MQTTProtocolVersion 

15from paho.mqtt.properties import Properties 

16from paho.mqtt.reasoncodes import ReasonCode 

17 

18from updates2mqtt.model import Discovery, ReleaseProvider 

19 

20from .config import HomeAssistantConfig, MqttConfig, NodeConfig, PublishPolicy 

21from .hass_formatter import hass_format_config, hass_format_state 

22 

# Module-level structlog logger (unbound; no host/integration context attached)
log = structlog.get_logger()

# Regex fragment matching one discovery-name topic segment: letters, digits, '_', '-' and '.'.
# Embedded in a capture group by the reverse_*_topic methods to recover the discovery name.
MQTT_NAME = r"[A-Za-z0-9_\-\.]+"

26 

27 

28@dataclass 

29class LocalMessage: 

30 topic: str | None = field(default=None) 

31 payload: str | None = field(default=None) 

32 

33 

class MqttPublisher:
    """Bridge between release providers and an MQTT broker.

    Publishes discovery, state and (Home Assistant) config payloads, subscribes to one
    command topic per provider, and schedules received commands onto an asyncio loop
    where the owning provider executes them.
    """

    # CONNACK reason names meaning the broker rejected our credentials. With MQTT v3.1.1,
    # paho 2.x reports CONNACK return codes 4/5 under these same reason names.
    _AUTH_FAILURE_REASONS = ("Not authorized", "Bad user name or password")

    def __init__(self, cfg: MqttConfig, node_cfg: NodeConfig, hass_cfg: HomeAssistantConfig) -> None:
        """Store configuration only; no network activity happens until ``start()``."""
        self.cfg: MqttConfig = cfg
        self.node_cfg: NodeConfig = node_cfg
        self.hass_cfg: HomeAssistantConfig = hass_cfg
        # Command topic -> provider that handles commands published on that topic
        self.providers_by_topic: dict[str, ReleaseProvider] = {}
        # Provider source_type -> provider, used to map topics back to discoveries
        self.providers_by_type: dict[str, ReleaseProvider] = {}
        # Loop that execute_command coroutines are scheduled on; set by start()
        self.event_loop: asyncio.AbstractEventLoop | None = None
        self.client: mqtt.Client | None = None
        # Set once the broker rejects our credentials; disables further work
        self.fatal_failure = Event()
        self.log = structlog.get_logger().bind(host=cfg.host, integration="mqtt")

    def start(self, event_loop: asyncio.AbstractEventLoop | None = None) -> None:
        """Create the main client, request a broker connection and start paho's network thread.

        Args:
            event_loop: loop on which incoming commands are executed; defaults to
                ``asyncio.get_event_loop()``.

        Raises:
            OSError: if creating the client or requesting the connection fails. The
                half-initialised client is discarded so ``is_available()`` stays False.
        """
        logger = self.log.bind(action="start")
        try:
            protocol: MQTTProtocolVersion
            if self.cfg.protocol in ("3", "3.11"):
                protocol = MQTTProtocolVersion.MQTTv311
            elif self.cfg.protocol == "3.1":
                protocol = MQTTProtocolVersion.MQTTv31
            elif self.cfg.protocol in ("5", "5.0"):
                protocol = MQTTProtocolVersion.MQTTv5
            else:
                logger.info("No valid MQTT protocol version found (%s), setting to default v3.11", self.cfg.protocol)
                protocol = MQTTProtocolVersion.MQTTv311
            logger.debug("MQTT protocol set to %r", protocol)

            self.event_loop = event_loop or asyncio.get_event_loop()
            self.client = mqtt.Client(
                callback_api_version=CallbackAPIVersion.VERSION2,
                client_id=f"updates2mqtt_{self.node_cfg.name}",
                # clean_session only exists for MQTT v3.x; MQTTv5 uses clean_start on connect()
                clean_session=True if protocol != MQTTProtocolVersion.MQTTv5 else None,
                protocol=protocol,
            )
            # Register callbacks before connecting so no broker event can arrive unhandled
            self.client.on_connect = self.on_connect
            self.client.on_disconnect = self.on_disconnect
            self.client.on_message = self.on_message
            self.client.on_subscribe = self.on_subscribe
            self.client.on_unsubscribe = self.on_unsubscribe

            self.client.username_pw_set(self.cfg.user, password=self.cfg.password)
            rc: MQTTErrorCode = self.client.connect(
                host=self.cfg.host,
                port=self.cfg.port,
                keepalive=60,
                clean_start=MQTT_CLEAN_START_FIRST_ONLY,
            )
            logger.info("Client connection requested", result_code=rc)

            self.client.loop_start()

            logger.debug("MQTT Publisher loop started", host=self.cfg.host, port=self.cfg.port)
        except Exception as e:
            # Don't leave a client behind that was never connected
            self.client = None
            logger.error("Failed to connect to broker", host=self.cfg.host, port=self.cfg.port, error=str(e))
            raise OSError(f"Connection Failure to {self.cfg.host}:{self.cfg.port} as {self.cfg.user} -- {e}") from e

    def stop(self) -> None:
        """Stop the network thread, disconnect and drop the main client (no-op if not started)."""
        if self.client:
            self.client.loop_stop()
            self.client.disconnect()
            self.client = None

    def is_available(self) -> bool:
        """True when a client exists and the broker has not rejected our credentials."""
        return self.client is not None and not self.fatal_failure.is_set()

    def on_connect(
        self, _client: mqtt.Client, _userdata: Any, _flags: mqtt.ConnectFlags, rc: ReasonCode, _props: Properties | None
    ) -> None:
        """paho connect callback: flag credential failures, otherwise (re)subscribe all command topics."""
        if not self.client or self.fatal_failure.is_set():
            self.log.warning("No client, check if started and authorized")
            return
        if rc.getName() in self._AUTH_FAILURE_REASONS:
            self.fatal_failure.set()
            self.log.error("Invalid MQTT credentials", result_code=rc)
            return
        if rc != 0:
            self.log.warning("Connection failed to broker", result_code=rc)
            return
        self.log.debug("Connected to broker", result_code=rc)
        # Iterate a snapshot: subscribe_hass_command may add topics from another thread meanwhile
        for topic, provider in list(self.providers_by_topic.items()):
            self.log.debug("(Re)subscribing", topic=topic, provider=provider.source_type)
            self.client.subscribe(topic)

    def on_disconnect(
        self,
        _client: mqtt.Client,
        _userdata: Any,
        _disconnect_flags: mqtt.DisconnectFlags,
        rc: ReasonCode,
        _props: Properties | None,
    ) -> None:
        """paho disconnect callback: log at debug for a clean disconnect, warning otherwise."""
        if rc == 0:
            self.log.debug("Disconnected from broker", result_code=rc)
        else:
            self.log.warning("Disconnect failure from broker", result_code=rc)

    async def clean_topics(
        self, provider: ReleaseProvider, wait_time: int = 5, max_time: int = 120, initial: bool = False
    ) -> None:
        """Sweep stale retained topics belonging to ``provider`` using a short-lived client.

        Subscribes to the Home Assistant ``update`` discovery tree and to this node's
        provider topic tree, then inspects every retained message:

        * topics that do not map back to a known discovery are deleted (unless ``initial``),
        * payloads that are not valid JSON dicts are deleted,
        * on an ``initial`` run, payloads with a truthy ``in_progress`` are deleted.

        The sweep ends once no relevant message has arrived for ``wait_time`` seconds, or
        after ``max_time`` seconds overall. Failures are logged, never raised. The network
        loop runs synchronously here, so this coroutine blocks its event loop meanwhile.
        """
        logger = self.log.bind(action="clean")

        if self.fatal_failure.is_set():
            return
        cleaner: mqtt.Client | None = None
        try:
            logger.info("Starting clean cycle, wait time: %s, max time: %s, initial: %s", wait_time, max_time, initial)
            cutoff_time: float = time.time() + max_time
            cleaner = mqtt.Client(
                # on_message has the same signature in VERSION1 and VERSION2; VERSION1 is deprecated
                callback_api_version=CallbackAPIVersion.VERSION2,
                client_id=f"updates2mqtt_clean_{self.node_cfg.name}",
                clean_session=True,
            )
            results = {"cleaned": 0, "matched": 0, "discovered": 0, "last_timestamp": time.time()}
            cleaner.username_pw_set(self.cfg.user, password=self.cfg.password)
            cleaner.connect(host=self.cfg.host, port=self.cfg.port, keepalive=60)

            config_prefix = f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{provider.source_type}_"
            node_prefix = f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/"

            def remove(client: mqtt.Client, topic: str) -> None:
                # An empty retained publish deletes the retained message held by the broker
                client.publish(topic, "", retain=True)
                results["cleaned"] += 1

            def cleanup(client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
                if not msg.payload:
                    # Zero-length payloads are deletions, including our own echoed back: the
                    # cleaner uses paho's default protocol (v3.1.1), which ignores noLocal.
                    # Re-deleting them would ping-pong with the broker until max_time.
                    return
                discovery: Discovery | None = None
                if msg.topic.startswith(config_prefix):
                    discovery = self.reverse_config_topic(msg.topic, provider.source_type)
                elif msg.topic.startswith(node_prefix) and msg.topic.endswith("/state"):
                    discovery = self.reverse_state_topic(msg.topic, provider.source_type)
                elif msg.topic.startswith(node_prefix):
                    discovery = self.reverse_general_topic(msg.topic, provider.source_type)
                else:
                    logger.debug("Ignoring other topic ", topic=msg.topic)
                    return
                results["discovered"] += 1
                # Any relevant message extends the quiet-period deadline
                results["last_timestamp"] = time.time()

                if discovery is None:
                    if not initial:
                        logger.debug("Removing unknown discovery", topic=msg.topic)
                        remove(client, msg.topic)
                        return  # already deleted; no need to inspect the payload
                else:
                    results["matched"] += 1

                try:
                    payload = json.loads(msg.payload)
                    if payload.get("in_progress") and initial:
                        remove(client, msg.topic)
                except Exception as e:
                    logger.warning("Invalid payload at %s: %s", msg.topic, e)
                    remove(client, msg.topic)

            cleaner.on_message = cleanup
            options = paho.mqtt.subscribeoptions.SubscribeOptions(noLocal=True)
            cleaner.subscribe(f"{self.hass_cfg.discovery.prefix}/update/#", options=options)
            cleaner.subscribe(f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/#", options=options)

            while time.time() - results["last_timestamp"] <= wait_time and time.time() <= cutoff_time:
                cleaner.loop(0.5)

            logger.info(
                f"Cleaned - discovered:{results['discovered']}, matched:{results['matched']}, cleaned:{results['cleaned']}"
            )
        except Exception as e:
            logger.exception("Cleaning topics of stale entries failed: %s", e)
        finally:
            # Release the dedicated connection; previously it was left open
            if cleaner is not None:
                try:
                    cleaner.disconnect()
                except Exception as e:
                    logger.debug("Cleaner disconnect failed: %s", e)

    def safe_json_decode(self, jsonish: str | bytes | None) -> dict:
        """Best-effort JSON decode.

        Tries ``jsonish`` as-is, then with its first and last characters stripped (e.g. an
        extra layer of quoting). Returns ``{}`` for ``None`` or when both attempts fail.
        """
        if jsonish is None:
            return {}
        try:
            return json.loads(jsonish)
        except Exception:
            self.log.exception("JSON decode fail (%s)", jsonish)
        try:
            return json.loads(jsonish[1:-1])
        except Exception:
            self.log.exception("JSON decode fail (%s)", jsonish[1:-1])
        return {}

    async def execute_command(
        self, msg: MQTTMessage | LocalMessage, on_update_start: Callable, on_update_end: Callable
    ) -> None:
        """Run a command message of the form ``"<source_type>|<component>|<command>"``.

        The provider is looked up by the message topic and must match ``source_type``.
        Only the ``install`` command is accepted. After a successful update, the
        discovery is republished according to its publish policy. All errors are logged
        rather than raised.
        """
        # TODO: defer handling of commands where repository is throttled
        logger = self.log.bind(topic=msg.topic, payload=msg.payload)
        comp_name: str | None = None
        command: str | None = None
        try:
            logger.info("Execution starting")
            source_type: str | None = None

            payload: str | None = None
            if isinstance(msg.payload, bytes):
                payload = msg.payload.decode("utf-8")
            elif isinstance(msg.payload, str):
                payload = msg.payload
            if payload and "|" in payload:
                # Any field count other than three raises ValueError, logged as "Execution failed"
                source_type, comp_name, command = payload.split("|")
                logger.debug("Executing %s:%s:%s", source_type, comp_name, command)

            provider: ReleaseProvider | None = self.providers_by_topic.get(msg.topic) if msg.topic else None
            if not provider:
                logger.warning("Unexpected provider type %s", msg.topic)
            elif provider.source_type != source_type:
                logger.warning("Unexpected source type %s", source_type)
            elif command != "install" or not comp_name:
                logger.warning("Invalid payload in command message: %s", msg.payload)
            else:
                logger.info(
                    "Passing %s command to %s scanner for %s",
                    command,
                    source_type,
                    comp_name,
                )
                updated: bool = provider.command(comp_name, command, on_update_start, on_update_end)
                discovery = provider.resolve(comp_name)
                if updated and discovery:
                    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT and self.hass_cfg.discovery.enabled:
                        self.publish_hass_config(discovery)
                    if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                        self.publish_discovery(discovery)
                    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                        self.publish_hass_state(discovery)
                else:
                    logger.debug("No change to republish after execution")
            logger.info("Execution ended")
        except Exception:
            logger.exception("Execution failed", component=comp_name, command=command)

    def local_message(self, discovery: Discovery, command: str) -> None:
        """Simulate an incoming MQTT message for local commands"""
        msg = LocalMessage(
            topic=self.command_topic(discovery.provider), payload="|".join([discovery.source_type, discovery.name, command])
        )
        self.handle_message(msg)

    def on_subscribe(
        self,
        _client: mqtt.Client,
        userdata: Any,
        mid: int,
        reason_code_list: list[ReasonCode],
        properties: Properties | None = None,
    ) -> None:
        """paho subscribe-acknowledgement callback; debug logging only."""
        self.log.debug(
            "on_subscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", userdata, mid, reason_code_list, properties
        )

    def on_unsubscribe(
        self,
        _client: mqtt.Client,
        userdata: Any,
        mid: int,
        reason_code_list: list[ReasonCode],
        properties: Properties | None = None,
    ) -> None:
        """paho unsubscribe-acknowledgement callback; debug logging only."""
        self.log.debug(
            "on_unsubscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", userdata, mid, reason_code_list, properties
        )

    def on_message(self, _client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
        """Callback for incoming MQTT messages"""  # noqa: D401
        if msg.topic in self.providers_by_topic:
            self.handle_message(msg)
        else:
            # apparently the root non-wildcard sub sometimes brings in child topics
            self.log.debug("Unhandled message #%s on %s:%s", msg.mid, msg.topic, msg.payload)

    def handle_message(self, msg: mqtt.MQTTMessage | LocalMessage) -> None:
        """Schedule ``execute_command`` for ``msg`` on the publisher's event loop (thread-safe).

        The progress callbacks handed to the provider republish state/discovery with
        ``in_progress`` set at the start and cleared at the end of an update.
        """

        def update_start(discovery: Discovery) -> None:
            self.log.debug("on_update_start: %s", discovery.name, topic=msg.topic)
            if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                self.publish_hass_state(discovery, in_progress=True)
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_discovery(discovery, in_progress=True)

        def update_end(discovery: Discovery) -> None:
            self.log.debug("on_update_end: %s", discovery.name, topic=msg.topic)
            if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                self.publish_hass_state(discovery, in_progress=False)
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_discovery(discovery, in_progress=False)

        # TODO: fix double publish on callback and in command exec
        if self.event_loop is not None:
            self.log.debug("Executing command topic", topic=msg.topic)
            asyncio.run_coroutine_threadsafe(
                self.execute_command(msg=msg, on_update_start=update_start, on_update_end=update_end), loop=self.event_loop
            )
        else:
            self.log.error("No event loop to handle message", topic=msg.topic)

    def _lookup_discovery(self, source_type: str, discovery_name: str) -> Discovery | None:
        """Return the named discovery of the provider registered for ``source_type``, else None."""
        provider = self.providers_by_type.get(source_type)
        if provider is not None and discovery_name in provider.discoveries:
            return provider.discoveries[discovery_name]
        return None

    def config_topic(self, discovery: Discovery) -> str:
        """Home Assistant discovery config topic for ``discovery``."""
        prefix = self.hass_cfg.discovery.prefix
        return f"{prefix}/update/{self.node_cfg.name}_{discovery.source_type}_{discovery.name}/update/config"

    def reverse_config_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a config topic (see ``config_topic``) back to its known discovery, or None."""
        # Configured values are literal topic text, so escape them before embedding in the regex
        match = re.fullmatch(
            f"{re.escape(self.hass_cfg.discovery.prefix)}/update/"
            f"{re.escape(self.node_cfg.name)}_{re.escape(source_type)}_({MQTT_NAME})/update/config",
            topic,
        )
        if match:
            discovery = self._lookup_discovery(source_type, match.group(1))
            if discovery is not None:
                return discovery

        self.log.debug("MQTT CONFIG no match for %s", topic)
        return None

    def state_topic(self, discovery: Discovery) -> str:
        """Home Assistant state topic for ``discovery``."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{discovery.source_type}/{discovery.name}/state"

    def reverse_state_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a state topic (see ``state_topic``) back to its known discovery, or None."""
        match = re.fullmatch(
            f"{re.escape(self.cfg.topic_root)}/{re.escape(self.node_cfg.name)}/{re.escape(source_type)}/({MQTT_NAME})/state",
            topic,
        )
        if match:
            # Unregistered source types yield None instead of raising KeyError
            discovery = self._lookup_discovery(source_type, match.group(1))
            if discovery is not None:
                return discovery

        self.log.debug("MQTT STATE no match for %s", topic)
        return None

    def general_topic(self, discovery: Discovery) -> str:
        """Generic (non Home Assistant specific) topic carrying the full discovery payload."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{discovery.source_type}/{discovery.name}"

    def reverse_general_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a general topic (see ``general_topic``) back to its known discovery, or None."""
        match = re.fullmatch(
            f"{re.escape(self.cfg.topic_root)}/{re.escape(self.node_cfg.name)}/{re.escape(source_type)}/({MQTT_NAME})",
            topic,
        )
        if match:
            # Unregistered source types yield None instead of raising KeyError
            discovery = self._lookup_discovery(source_type, match.group(1))
            if discovery is not None:
                return discovery

        self.log.debug("MQTT ATTR no match for %s", topic)
        return None

    def command_topic(self, provider: ReleaseProvider) -> str:
        """Topic on which commands for ``provider`` are received."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}"

    def publish_discovery(self, discovery: Discovery, in_progress: bool = False) -> None:
        """Comprehensive, non Home Assistant specific, base publication"""
        if discovery.publish_policy not in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
            return
        self.log.debug("Discovery publish: %s", discovery)
        payload: dict[str, Any] = discovery.as_dict()
        payload["update"]["in_progress"] = in_progress  # ty:ignore[invalid-assignment]
        # Tolerate a missing or null "release" section
        release = payload.get("release") or {}
        if release.get("summary") and self.hass_cfg.release_summary_max_size:
            release["summary"] = release["summary"][: self.hass_cfg.release_summary_max_size]
        self.publish(self.general_topic(discovery), payload)

    def publish_hass_state(self, discovery: Discovery, in_progress: bool = False) -> None:
        """Publish the Home Assistant state payload (HOMEASSISTANT publish policy only)."""
        if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
            return
        self.log.debug("HASS State update, in progress: %s, discovery: %s", in_progress, discovery)
        self.publish(
            self.state_topic(discovery),
            hass_format_state(
                discovery, in_progress=in_progress, release_summary_max_size=self.hass_cfg.release_summary_max_size
            ),
        )

    def publish_hass_config(self, discovery: Discovery) -> None:
        """Publish the Home Assistant discovery config (HOMEASSISTANT publish policy only)."""
        if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
            return
        object_id = f"{discovery.source_type}_{self.node_cfg.name}_{discovery.name}"
        self.log.debug("HASS Config: %s", object_id)

        self.publish(
            self.config_topic(discovery),
            hass_format_config(
                discovery=discovery,
                object_id=object_id,
                area=self.hass_cfg.area,
                state_topic=self.state_topic(discovery),
                attrs_topic=self.general_topic(discovery) if self.hass_cfg.extra_attributes else None,
                command_topic=self.command_topic(discovery.provider),
                force_command_topic=self.hass_cfg.force_command_topic,
                device_creation=self.hass_cfg.device_creation,
            ),
        )

    def subscribe_hass_command(self, provider: ReleaseProvider) -> str:
        """Register ``provider`` and subscribe to its command topic; return that topic.

        Skips topics already registered, and does nothing before ``start()`` has created
        a client.
        """
        topic = self.command_topic(provider)
        if topic in self.providers_by_topic or self.client is None:
            self.log.debug("Skipping subscription", topic=topic)
        else:
            self.log.info("Handler subscribing", topic=topic)
            self.providers_by_topic[topic] = provider
            self.providers_by_type[provider.source_type] = provider
            self.client.subscribe(topic)
        return topic

    def loop_once(self) -> None:
        """Run a single iteration of the main client's network loop, if a client exists."""
        if self.client:
            self.client.loop()

    def publish(self, topic: str, payload: dict, qos: int = 0, retain: bool = True) -> None:
        """JSON-encode ``payload`` and publish it (retained by default); logs and skips if no client."""
        if self.client:
            info: MQTTMessageInfo = self.client.publish(topic, payload=json.dumps(payload), qos=qos, retain=retain)
            if info.rc == MQTTErrorCode.MQTT_ERR_SUCCESS:
                self.log.debug("Publish to %s, mid: %s, published: %s, rc: %s", topic, info.mid, info.is_published(), info.rc)
            else:
                self.log.warning("Problem publishing to %s, mid: %s, rc: %s", topic, info.mid, info.rc)
        else:
            self.log.debug("No client to publish at %s", topic)