Coverage for src / updates2mqtt / mqtt.py: 77%

281 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-03 23:58 +0000

1import asyncio 

2import json 

3import re 

4import time 

5from collections.abc import Callable 

6from dataclasses import dataclass, field 

7from threading import Event 

8from typing import Any 

9 

10import paho.mqtt.client as mqtt 

11import paho.mqtt.subscribeoptions 

12import structlog 

13from paho.mqtt.client import MQTT_CLEAN_START_FIRST_ONLY, MQTTMessage, MQTTMessageInfo 

14from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode, MQTTProtocolVersion 

15from paho.mqtt.properties import Properties 

16from paho.mqtt.reasoncodes import ReasonCode 

17 

18from updates2mqtt.model import Discovery, ReleaseProvider 

19 

20from .config import HomeAssistantConfig, MqttConfig, NodeConfig, PublishPolicy 

21from .hass_formatter import hass_format_config, hass_format_state 

22 

log = structlog.get_logger()

# Pattern for a single topic-level name (ASCII letters, digits, "_", ".", "-").
# It is interpolated as the capture group when mapping topics back to discovery names.
MQTT_NAME = r"[A-Za-z0-9_.-]+"

26 

27 

28@dataclass 

29class LocalMessage: 

30 topic: str | None = field(default=None) 

31 payload: str | None = field(default=None) 

32 

33 

class MqttPublisher:
    """Bridge between release providers and an MQTT broker.

    Publishes discovery data (generic MQTT and Home Assistant formats) as JSON
    messages (retained by default), subscribes to one command topic per
    provider, and schedules received commands on an asyncio event loop.
    """

    def __init__(self, cfg: MqttConfig, node_cfg: NodeConfig, hass_cfg: HomeAssistantConfig) -> None:
        self.cfg: MqttConfig = cfg
        self.node_cfg: NodeConfig = node_cfg
        self.hass_cfg: HomeAssistantConfig = hass_cfg
        # Command topic -> provider; these topics are (re)subscribed in on_connect.
        self.providers_by_topic: dict[str, ReleaseProvider] = {}
        # Provider source_type -> provider; used to map topics back to discoveries.
        self.providers_by_type: dict[str, ReleaseProvider] = {}
        self.event_loop: asyncio.AbstractEventLoop | None = None
        self.client: mqtt.Client | None = None
        # Set when the broker rejects our credentials; makes is_available() False.
        self.fatal_failure = Event()
        self.log = structlog.get_logger().bind(host=cfg.host, integration="mqtt")

    def start(self, event_loop: asyncio.AbstractEventLoop | None = None) -> None:
        """Create the MQTT client, request a broker connection and start paho's network thread.

        Args:
            event_loop: loop on which command coroutines are scheduled; defaults to
                ``asyncio.get_event_loop()``.

        Raises:
            OSError: if creating the client or requesting the connection failed. In that
                case ``self.client`` is reset to ``None``.
        """
        logger = self.log.bind(action="start")
        try:
            protocol: MQTTProtocolVersion
            if self.cfg.protocol in ("3", "3.11", "3.1.1"):
                protocol = MQTTProtocolVersion.MQTTv311
            elif self.cfg.protocol == "3.1":
                protocol = MQTTProtocolVersion.MQTTv31
            elif self.cfg.protocol in ("5", "5.0"):
                protocol = MQTTProtocolVersion.MQTTv5
            else:
                logger.info("No valid MQTT protocol version found (%s), setting to default v3.11", self.cfg.protocol)
                protocol = MQTTProtocolVersion.MQTTv311
            logger.debug("MQTT protocol set to %r", protocol)

            self.event_loop = event_loop or asyncio.get_event_loop()
            self.client = mqtt.Client(
                callback_api_version=CallbackAPIVersion.VERSION2,
                client_id=f"updates2mqtt_{self.node_cfg.name}",
                # clean_session only applies to MQTT 3.x; paho requires None for v5.
                clean_session=True if protocol != MQTTProtocolVersion.MQTTv5 else None,
                protocol=protocol,
            )
            self.client.username_pw_set(self.cfg.user, password=self.cfg.password)

            # Install callbacks before connecting so they are in place for the CONNACK.
            self.client.on_connect = self.on_connect
            self.client.on_disconnect = self.on_disconnect
            self.client.on_message = self.on_message
            self.client.on_subscribe = self.on_subscribe
            self.client.on_unsubscribe = self.on_unsubscribe

            rc: MQTTErrorCode = self.client.connect(
                host=self.cfg.host,
                port=self.cfg.port,
                keepalive=60,
                clean_start=MQTT_CLEAN_START_FIRST_ONLY,
            )
            logger.info("Client connection requested", result_code=rc)

            self.client.loop_start()

            logger.debug("MQTT Publisher loop started", host=self.cfg.host, port=self.cfg.port)
        except Exception as e:
            # Don't keep a client that never connected: is_available() would report True.
            self.client = None
            logger.error("Failed to connect to broker", host=self.cfg.host, port=self.cfg.port, error=str(e))
            raise OSError(f"Connection Failure to {self.cfg.host}:{self.cfg.port} as {self.cfg.user} -- {e}") from e

    def stop(self) -> None:
        """Disconnect from the broker and stop the network thread; a no-op when not started."""
        if self.client:
            # Disconnect while the network loop is still running (the order paho documents),
            # then stop the loop thread.
            self.client.disconnect()
            self.client.loop_stop()
            self.client = None

    def is_available(self) -> bool:
        """Return True when a client exists and the broker has not rejected our credentials."""
        return self.client is not None and not self.fatal_failure.is_set()

    def on_connect(
        self, _client: mqtt.Client, _userdata: Any, _flags: mqtt.ConnectFlags, rc: ReasonCode, _props: Properties | None
    ) -> None:
        """Paho connect callback; called for each CONNACK the broker sends.

        A credential rejection sets ``fatal_failure``. A successful connection
        re-subscribes every topic registered in ``providers_by_topic``.
        """
        if not self.client or self.fatal_failure.is_set():
            self.log.warning("No client, check if started and authorized")
            return
        if rc.getName() in ("Not authorized", "Bad user name or password"):
            # Reconnecting with the same credentials will not help.
            self.fatal_failure.set()
            self.log.error("Invalid MQTT credentials", result_code=rc)
            return
        if rc != 0:
            self.log.warning("Connection failed to broker", result_code=rc)
        else:
            self.log.debug("Connected to broker", result_code=rc)
            for topic, provider in self.providers_by_topic.items():
                self.log.debug("(Re)subscribing", topic=topic, provider=provider.source_type)
                self.client.subscribe(topic)

    def on_disconnect(
        self,
        _client: mqtt.Client,
        _userdata: Any,
        _disconnect_flags: mqtt.DisconnectFlags,
        rc: ReasonCode,
        _props: Properties | None,
    ) -> None:
        """Paho disconnect callback: log a clean disconnect at debug level, anything else as a warning."""
        if rc == 0:
            self.log.debug("Disconnected from broker", result_code=rc)
        else:
            self.log.warning("Disconnect failure from broker", result_code=rc)

    async def clean_topics(
        self, provider: ReleaseProvider, wait_time: int = 5, max_time: int = 120, initial: bool = False
    ) -> None:
        """Remove stale retained messages belonging to ``provider`` from the broker.

        A short-lived second client subscribes to the Home Assistant update
        discovery tree and to this node's topics for the provider. For each
        message on a topic owned by this node/provider:

        * if the topic does not map back to a known discovery and ``initial`` is
          False, its non-empty retained payload is cleared;
        * otherwise, non-empty payloads that are not JSON objects are cleared, and,
          when ``initial`` is True, payloads flagged ``in_progress`` are cleared.

        Clearing means publishing an empty retained payload. Listening stops once
        no relevant message has arrived for ``wait_time`` seconds, or ``max_time``
        seconds after the start. Errors are logged, not raised, and the cleaner
        connection is always closed.
        """
        logger = self.log.bind(action="clean")
        if self.fatal_failure.is_set():
            return
        cleaner: mqtt.Client | None = None
        try:
            logger.info("Starting clean cycle, wait time: %s, max time: %s, initial: %s", wait_time, max_time, initial)
            cutoff_time: float = time.time() + max_time
            client = mqtt.Client(
                callback_api_version=CallbackAPIVersion.VERSION1,
                client_id=f"updates2mqtt_clean_{self.node_cfg.name}",
                clean_session=True,
            )
            cleaner = client
            results = {"cleaned": 0, "matched": 0, "discovered": 0, "last_timestamp": time.time()}
            client.username_pw_set(self.cfg.user, password=self.cfg.password)
            client.connect(host=self.cfg.host, port=self.cfg.port, keepalive=60)

            config_prefix = f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{provider.source_type}_"
            node_prefix = f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/"

            def clear(topic: str) -> None:
                """Delete the retained message on ``topic`` by publishing an empty retained payload."""
                client.publish(topic, "", retain=True)
                results["cleaned"] += 1

            def cleanup(_client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
                discovery: Discovery | None = None
                if msg.topic.startswith(config_prefix):
                    discovery = self.reverse_config_topic(msg.topic, provider.source_type)
                elif msg.topic.startswith(node_prefix) and msg.topic.endswith("/state"):
                    discovery = self.reverse_state_topic(msg.topic, provider.source_type)
                elif msg.topic.startswith(node_prefix):
                    discovery = self.reverse_general_topic(msg.topic, provider.source_type)
                else:
                    logger.debug("Ignoring other topic ", topic=msg.topic)
                    return
                results["discovered"] += 1
                if discovery is not None:
                    results["matched"] += 1
                elif not initial:
                    # An empty payload means nothing is retained there any more. This includes
                    # the broker echoing our own deletions back: noLocal is not honoured on
                    # non-v5 connections. Re-publishing those would repeat until max_time.
                    if msg.payload:
                        logger.debug("Removing unknown discovery", topic=msg.topic)
                        clear(msg.topic)
                    results["last_timestamp"] = time.time()
                    return

                try:
                    if msg.payload:
                        payload = json.loads(msg.payload)
                        if payload.get("in_progress") and initial:
                            clear(msg.topic)
                except Exception as e:
                    # Not JSON, or not a JSON object (no .get): treat as corrupt and remove it.
                    logger.warning("Invalid payload at %s: %s", msg.topic, e)
                    clear(msg.topic)

                results["last_timestamp"] = time.time()

            client.on_message = cleanup
            # noLocal is an MQTT v5 subscribe option. This client is created without a protocol
            # argument, so do not rely on it; the empty-payload guard above handles the echoes.
            options = paho.mqtt.subscribeoptions.SubscribeOptions(noLocal=True)
            client.subscribe(f"{self.hass_cfg.discovery.prefix}/update/#", options=options)
            client.subscribe(f"{node_prefix}#", options=options)

            while time.time() - results["last_timestamp"] <= wait_time and time.time() <= cutoff_time:
                client.loop(0.5)
                # Yield between network polls so other tasks on this event loop get a turn.
                await asyncio.sleep(0)

            logger.info(
                f"Cleaned - discovered:{results['discovered']}, matched:{results['matched']}, cleaned:{results['cleaned']}"
            )
        except Exception as e:
            logger.exception("Cleaning topics of stale entries failed: %s", e)
        finally:
            if cleaner is not None:
                # Previously the cleaner's connection was never closed, leaking one per cycle.
                try:
                    cleaner.disconnect()
                except Exception as e:
                    logger.debug("Cleaner disconnect failed: %s", e)

    def safe_json_decode(self, jsonish: str | bytes | None) -> dict:
        """Parse ``jsonish`` as JSON, retrying once with the outer characters stripped.

        The retry drops the first and last character (e.g. a JSON document wrapped
        in an extra pair of quotes). Returns ``{}`` for ``None`` or when both
        attempts fail; each failure is logged.
        NOTE(review): a valid non-object document (e.g. a list) is returned as-is
        despite the ``dict`` annotation.
        """
        if jsonish is None:
            return {}
        try:
            return json.loads(jsonish)
        except Exception:
            self.log.exception("JSON decode fail (%s)", jsonish)
        try:
            return json.loads(jsonish[1:-1])
        except Exception:
            self.log.exception("JSON decode fail (%s)", jsonish[1:-1])
        return {}

    async def execute_command(
        self, msg: MQTTMessage | LocalMessage, on_update_start: Callable, on_update_end: Callable
    ) -> None:
        """Execute a ``<source_type>|<component>|install`` command and republish the result.

        The message topic selects the provider via ``providers_by_topic``. The
        payload must name that provider's ``source_type``, and only the ``install``
        command is accepted. When the provider reports an update, the resolved
        discovery is republished according to its publish policy. All failures
        are logged; nothing is raised.

        Args:
            msg: received (or locally simulated) command message.
            on_update_start: passed through to ``provider.command``.
            on_update_end: passed through to ``provider.command``.
        """
        # TODO: defer handling of commands where repository is throttled
        logger = self.log.bind(topic=msg.topic, payload=msg.payload)
        comp_name: str | None = None
        command: str | None = None
        try:
            logger.info("Execution starting")
            source_type: str | None = None

            payload: str | None = None
            if isinstance(msg.payload, bytes):
                payload = msg.payload.decode("utf-8")
            elif isinstance(msg.payload, str):
                payload = msg.payload
            if payload and "|" in payload:
                # Exactly three fields are expected; any other count raises ValueError,
                # which the handler below reports as an execution failure.
                source_type, comp_name, command = payload.split("|")
                logger.debug("Executing %s:%s:%s", source_type, comp_name, command)

            provider: ReleaseProvider | None = self.providers_by_topic.get(msg.topic) if msg.topic else None
            if not provider:
                logger.warning("Unexpected provider type %s", msg.topic)
            elif provider.source_type != source_type:
                logger.warning("Unexpected source type %s", source_type)
            elif command != "install" or not comp_name:
                logger.warning("Invalid payload in command message: %s", msg.payload)
            else:
                logger.info(
                    "Passing %s command to %s scanner for %s",
                    command,
                    source_type,
                    comp_name,
                )
                updated: bool = provider.command(comp_name, command, on_update_start, on_update_end)
                discovery = provider.resolve(comp_name)
                if updated and discovery:
                    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT and self.hass_cfg.discovery.enabled:
                        self.publish_hass_config(discovery)
                    if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                        self.publish_discovery(discovery)
                    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                        self.publish_hass_state(discovery)
                else:
                    logger.debug("No change to republish after execution")
            logger.info("Execution ended")
        except Exception:
            logger.exception("Execution failed", component=comp_name, command=command)

    def local_message(self, discovery: Discovery, command: str) -> None:
        """Simulate an incoming MQTT message for local commands"""
        msg = LocalMessage(
            topic=self.command_topic(discovery.provider), payload="|".join([discovery.source_type, discovery.name, command])
        )
        self.handle_message(msg)

    def on_subscribe(
        self,
        _client: mqtt.Client,
        userdata: Any,
        mid: int,
        reason_code_list: list[ReasonCode],
        properties: Properties | None = None,
    ) -> None:
        """Paho subscribe-acknowledgement callback; only logs the broker's reply."""
        self.log.debug(
            "on_subscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", userdata, mid, reason_code_list, properties
        )

    def on_unsubscribe(
        self,
        _client: mqtt.Client,
        userdata: Any,
        mid: int,
        reason_code_list: list[ReasonCode],
        properties: Properties | None = None,
    ) -> None:
        """Paho unsubscribe-acknowledgement callback; only logs the broker's reply."""
        self.log.debug(
            "on_unsubscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", userdata, mid, reason_code_list, properties
        )

    def on_message(self, _client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
        """Callback for incoming MQTT messages"""  # noqa: D401
        if msg.topic in self.providers_by_topic:
            self.handle_message(msg)
        else:
            # apparently the root non-wildcard sub sometimes brings in child topics
            self.log.debug("Unhandled message #%s on %s:%s", msg.mid, msg.topic, msg.payload)

    def handle_message(self, msg: mqtt.MQTTMessage | LocalMessage) -> None:
        """Submit :meth:`execute_command` for ``msg`` to ``self.event_loop``.

        Uses ``asyncio.run_coroutine_threadsafe``, so it may be called from paho's
        network thread. The start/end callbacks handed to the provider republish
        the discovery with ``in_progress`` set to True/False.
        """

        def update_start(discovery: Discovery) -> None:
            self.log.debug("on_update_start", topic=msg.topic)
            if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                self.publish_hass_state(discovery, in_progress=True)
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_discovery(discovery, in_progress=True)

        def update_end(discovery: Discovery) -> None:
            self.log.debug("on_update_end", topic=msg.topic)
            if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                self.publish_hass_state(discovery, in_progress=False)
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_discovery(discovery, in_progress=False)

        # TODO: fix double publish on callback and in command exec
        if self.event_loop is not None:
            self.log.debug("Executing command topic", topic=msg.topic)
            asyncio.run_coroutine_threadsafe(
                self.execute_command(msg=msg, on_update_start=update_start, on_update_end=update_end), loop=self.event_loop
            )
        else:
            self.log.error("No event loop to handle message", topic=msg.topic)

    def _find_discovery(self, source_type: str, discovery_name: str) -> Discovery | None:
        """Return the discovery ``discovery_name`` of the provider registered for ``source_type``, if any."""
        provider = self.providers_by_type.get(source_type)
        if provider is not None and discovery_name in provider.discoveries:
            return provider.discoveries[discovery_name]
        return None

    def config_topic(self, discovery: Discovery) -> str:
        """Home Assistant discovery config topic for ``discovery``."""
        prefix = self.hass_cfg.discovery.prefix
        return f"{prefix}/update/{self.node_cfg.name}_{discovery.source_type}_{discovery.name}/update/config"

    def reverse_config_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a topic built by :meth:`config_topic` back to its registered discovery, or None."""
        # Configured fragments are escaped so characters like "." match literally.
        match = re.fullmatch(
            f"{re.escape(self.hass_cfg.discovery.prefix)}/update/"
            f"{re.escape(self.node_cfg.name)}_{re.escape(source_type)}_({MQTT_NAME})/update/config",
            topic,
        )
        discovery = self._find_discovery(source_type, match.group(1)) if match else None
        if discovery is None:
            self.log.debug("MQTT CONFIG no match for %s", topic)
        return discovery

    def state_topic(self, discovery: Discovery) -> str:
        """Home Assistant state topic for ``discovery``."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{discovery.source_type}/{discovery.name}/state"

    def reverse_state_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a topic built by :meth:`state_topic` back to its registered discovery, or None."""
        match = re.fullmatch(
            f"{re.escape(self.cfg.topic_root)}/{re.escape(self.node_cfg.name)}/"
            f"{re.escape(source_type)}/({MQTT_NAME})/state",
            topic,
        )
        discovery = self._find_discovery(source_type, match.group(1)) if match else None
        if discovery is None:
            self.log.debug("MQTT STATE no match for %s", topic)
        return discovery

    def general_topic(self, discovery: Discovery) -> str:
        """Generic (non Home Assistant) topic carrying the full discovery payload."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{discovery.source_type}/{discovery.name}"

    def reverse_general_topic(self, topic: str, source_type: str) -> Discovery | None:
        """Map a topic built by :meth:`general_topic` back to its registered discovery, or None."""
        match = re.fullmatch(
            f"{re.escape(self.cfg.topic_root)}/{re.escape(self.node_cfg.name)}/{re.escape(source_type)}/({MQTT_NAME})",
            topic,
        )
        discovery = self._find_discovery(source_type, match.group(1)) if match else None
        if discovery is None:
            self.log.debug("MQTT ATTR no match for %s", topic)
        return discovery

    def command_topic(self, provider: ReleaseProvider) -> str:
        """Topic on which commands for ``provider`` are received."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}"

    def publish_discovery(self, discovery: Discovery, in_progress: bool = False) -> None:
        """Comprehensive, non Home Assistant specific, base publication"""
        if discovery.publish_policy not in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
            return
        self.log.debug("Discovery publish: %s", discovery)
        payload: dict[str, Any] = discovery.as_dict()
        payload["update"]["in_progress"] = in_progress  # ty:ignore[invalid-assignment]
        self.publish(self.general_topic(discovery), payload)

    def publish_hass_state(self, discovery: Discovery, in_progress: bool = False) -> None:
        """Publish the Home Assistant state payload for ``discovery`` (HOMEASSISTANT policy only)."""
        if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
            return
        self.log.debug("HASS State update, in progress: %s, discovery: %s", in_progress, discovery)
        self.publish(
            self.state_topic(discovery),
            hass_format_state(
                discovery, in_progress=in_progress, release_summary_max_size=self.hass_cfg.release_summary_max_size
            ),
        )

    def publish_hass_config(self, discovery: Discovery) -> None:
        """Publish the Home Assistant discovery config for ``discovery`` (HOMEASSISTANT policy only)."""
        if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
            return
        object_id = f"{discovery.source_type}_{self.node_cfg.name}_{discovery.name}"
        self.log.debug("HASS Config: %s", object_id)

        self.publish(
            self.config_topic(discovery),
            hass_format_config(
                discovery=discovery,
                object_id=object_id,
                area=self.hass_cfg.area,
                state_topic=self.state_topic(discovery),
                attrs_topic=self.general_topic(discovery) if self.hass_cfg.extra_attributes else None,
                command_topic=self.command_topic(discovery.provider),
                force_command_topic=self.hass_cfg.force_command_topic,
                device_creation=self.hass_cfg.device_creation,
            ),
        )

    def subscribe_hass_command(self, provider: ReleaseProvider) -> str:
        """Register ``provider`` and subscribe to its command topic; return the topic.

        Skipped (nothing registered) when the topic is already registered or no
        client has been started.
        """
        topic = self.command_topic(provider)
        if topic in self.providers_by_topic or self.client is None:
            self.log.debug("Skipping subscription", topic=topic)
        else:
            self.log.info("Handler subscribing", topic=topic)
            self.providers_by_topic[topic] = provider
            self.providers_by_type[provider.source_type] = provider
            self.client.subscribe(topic)
        return topic

    def loop_once(self) -> None:
        """Run a single iteration of the client's network loop, if a client exists."""
        if self.client:
            self.client.loop()

    def publish(self, topic: str, payload: dict, qos: int = 0, retain: bool = True) -> None:
        """JSON-encode ``payload`` and publish it on ``topic`` (retained by default); log the outcome."""
        if self.client:
            info: MQTTMessageInfo = self.client.publish(topic, payload=json.dumps(payload), qos=qos, retain=retain)
            if info.rc == MQTTErrorCode.MQTT_ERR_SUCCESS:
                self.log.debug("Publish to %s, mid: %s, published: %s, rc: %s", topic, info.mid, info.is_published(), info.rc)
            else:
                self.log.warning("Problem publishing to %s, mid: %s, rc: %s", topic, info.mid, info.rc)
        else:
            self.log.debug("No client to publish at %s", topic)