Coverage for src / updates2mqtt / mqtt.py: 88%

231 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-20 02:29 +0000

1import asyncio 

2import json 

3import time 

4from collections.abc import Callable 

5from dataclasses import dataclass, field 

6from threading import Event 

7from typing import Any 

8 

9import paho.mqtt.client as mqtt 

10import paho.mqtt.subscribeoptions 

11import structlog 

12from paho.mqtt.client import MQTT_CLEAN_START_FIRST_ONLY, MQTTMessage 

13from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode, MQTTProtocolVersion 

14from paho.mqtt.properties import Properties 

15from paho.mqtt.reasoncodes import ReasonCode 

16 

17from updates2mqtt.model import Discovery, ReleaseProvider 

18 

19from .config import HomeAssistantConfig, MqttConfig, NodeConfig, PublishPolicy 

20from .hass_formatter import hass_format_config, hass_format_state 

21 

# Module-level structlog logger (unbound). MqttPublisher instances bind their own
# logger with host/integration context in __init__ and expose it as self.log.
log = structlog.get_logger()

23 

24 

25@dataclass 

26class LocalMessage: 

27 topic: str | None = field(default=None) 

28 payload: str | None = field(default=None) 

29 

30 

class MqttPublisher:
    """Bridge between release providers and an MQTT broker.

    A long-lived paho client (created by ``start``) publishes discovery, state and
    Home Assistant config messages, and receives install commands on one command
    topic per provider. Incoming commands are executed as coroutines scheduled on
    ``event_loop``.
    """

    def __init__(self, cfg: MqttConfig, node_cfg: NodeConfig, hass_cfg: HomeAssistantConfig) -> None:
        self.cfg: MqttConfig = cfg
        self.node_cfg: NodeConfig = node_cfg
        self.hass_cfg: HomeAssistantConfig = hass_cfg
        # Command topic -> provider that handles commands received on that topic.
        self.providers_by_topic: dict[str, ReleaseProvider] = {}
        # Loop on which incoming commands are executed; assigned by start().
        self.event_loop: asyncio.AbstractEventLoop | None = None
        self.client: mqtt.Client | None = None
        # Set when the broker rejects the credentials; is_available() then reports False
        # and on_connect/clean_topics return early.
        self.fatal_failure = Event()
        self.log = structlog.get_logger().bind(host=cfg.host, integration="mqtt")

    def start(self, event_loop: asyncio.AbstractEventLoop | None = None) -> None:
        """Create the MQTT client, connect to the broker and start paho's network thread.

        Args:
            event_loop: loop on which incoming commands are executed; defaults to
                ``asyncio.get_event_loop()``.

        Raises:
            OSError: wrapping any failure while creating or connecting the client.
        """
        logger = self.log.bind(action="start")
        try:
            protocol: MQTTProtocolVersion
            # "3.1.1" is the official version string; "3" and "3.11" are kept for existing configs.
            if self.cfg.protocol in ("3", "3.11", "3.1.1"):
                protocol = MQTTProtocolVersion.MQTTv311
            elif self.cfg.protocol == "3.1":
                protocol = MQTTProtocolVersion.MQTTv31
            elif self.cfg.protocol in ("5", "5.0"):
                protocol = MQTTProtocolVersion.MQTTv5
            else:
                logger.info("No valid MQTT protocol version found (%s), setting to default v3.11", self.cfg.protocol)
                protocol = MQTTProtocolVersion.MQTTv311
            logger.debug("MQTT protocol set to %r", protocol)

            self.event_loop = event_loop or asyncio.get_event_loop()
            self.client = mqtt.Client(
                callback_api_version=CallbackAPIVersion.VERSION2,
                client_id=f"updates2mqtt_{self.node_cfg.name}",
                # clean_session only exists for MQTT v3.x; v5 uses clean_start on connect()
                clean_session=True if protocol != MQTTProtocolVersion.MQTTv5 else None,
                protocol=protocol,
            )
            # Register callbacks before connecting so they are in place for the first broker event.
            self.client.on_connect = self.on_connect
            self.client.on_disconnect = self.on_disconnect
            self.client.on_message = self.on_message
            self.client.on_subscribe = self.on_subscribe
            self.client.on_unsubscribe = self.on_unsubscribe

            self.client.username_pw_set(self.cfg.user, password=self.cfg.password)
            rc: MQTTErrorCode = self.client.connect(
                host=self.cfg.host,
                port=self.cfg.port,
                keepalive=60,
                clean_start=MQTT_CLEAN_START_FIRST_ONLY,
            )
            logger.info("Client connection requested", result_code=rc)

            self.client.loop_start()

            logger.debug("MQTT Publisher loop started", host=self.cfg.host, port=self.cfg.port)
        except Exception as e:
            logger.error("Failed to connect to broker", host=self.cfg.host, port=self.cfg.port, error=str(e))
            raise OSError(f"Connection Failure to {self.cfg.host}:{self.cfg.port} as {self.cfg.user} -- {e}") from e

    def stop(self) -> None:
        """Stop the network thread, disconnect, and drop the client (no-op if not started)."""
        if self.client:
            self.client.loop_stop()
            self.client.disconnect()
            self.client = None

    def is_available(self) -> bool:
        """Return True when a client exists and no fatal (credential) failure was recorded."""
        return self.client is not None and not self.fatal_failure.is_set()

    def on_connect(
        self, _client: mqtt.Client, _userdata: Any, _flags: mqtt.ConnectFlags, rc: ReasonCode, _props: Properties | None
    ) -> None:
        """paho VERSION2 connect callback.

        Marks a fatal failure on "Not authorized"; otherwise, on success, re-subscribes
        every registered command topic.
        """
        if not self.client or self.fatal_failure.is_set():
            self.log.warning("No client, check if started and authorized")
            return
        if rc.getName() == "Not authorized":
            self.fatal_failure.set()
            # Use the bound logger so host/integration context is included.
            self.log.error("Invalid MQTT credentials", result_code=rc)
            return
        if rc != 0:
            self.log.warning("Connection failed to broker", result_code=rc)
        else:
            self.log.debug("Connected to broker", result_code=rc)
            # Re-subscribe on every (re)connect in case the broker did not keep the session's subscriptions.
            for topic, provider in self.providers_by_topic.items():
                self.log.debug("(Re)subscribing", topic=topic, provider=provider.source_type)
                self.client.subscribe(topic)

    def on_disconnect(
        self,
        _client: mqtt.Client,
        _userdata: Any,
        _disconnect_flags: mqtt.DisconnectFlags,
        rc: ReasonCode,
        _props: Properties | None,
    ) -> None:
        """paho VERSION2 disconnect callback: log a clean or failed disconnect."""
        if rc == 0:
            self.log.debug("Disconnected from broker", result_code=rc)
        else:
            self.log.warning("Disconnect failure from broker", result_code=rc)

    async def clean_topics(
        self, provider: ReleaseProvider, last_scan_session: str | None, wait_time: int = 5, force: bool = False
    ) -> None:
        """Clear retained messages for ``provider`` that are stale.

        Connects a separate short-lived client, subscribes to the Home Assistant
        ``update`` discovery topics and this node/provider's topic subtree, and for
        each retained message under this node/provider's prefixes publishes an empty
        retained payload (which removes the retained message) when:

        * its ``source_session`` differs from ``last_scan_session`` (both non-None), or
        * it carries no ``source_session`` and ``force`` is True.

        Listening stops once ``wait_time`` seconds pass without a matching retained
        message, or when the cleaner's network loop reports an error.
        """
        logger = self.log.bind(action="clean")
        if self.fatal_failure.is_set():
            return
        logger.info("Starting clean cycle")
        cleaner = mqtt.Client(
            # VERSION1 is deprecated in paho 2.x; on_message has the same signature in VERSION2.
            callback_api_version=CallbackAPIVersion.VERSION2,
            client_id=f"updates2mqtt_clean_{self.node_cfg.name}",
            clean_session=True,
        )
        results = {"cleaned": 0, "handled": 0, "discovered": 0, "last_timestamp": time.time()}
        cleaner.username_pw_set(self.cfg.user, password=self.cfg.password)
        cleaner.connect(host=self.cfg.host, port=self.cfg.port, keepalive=60)
        prefixes = [
            f"{self.hass_cfg.discovery.prefix}/update/{self.node_cfg.name}_{provider.source_type}_",
            f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/",
        ]

        def cleanup(_client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
            """Decide per retained message whether to clear it; updates ``results``."""
            if msg.retain and any(msg.topic.startswith(prefix) for prefix in prefixes):
                session = None
                results["discovered"] += 1
                try:
                    payload = self.safe_json_decode(msg.payload)
                    session = payload.get("source_session")
                except Exception as e:
                    logger.warning(
                        "Unable to handle payload for %s: %s",
                        msg.topic,
                        e,
                        exc_info=1,
                    )
                results["handled"] += 1
                # Only matching retained messages extend the listening window.
                results["last_timestamp"] = time.time()
                if session is not None and last_scan_session is not None and session != last_scan_session:
                    logger.debug("Removing stale msg", topic=msg.topic, session=session)
                    cleaner.publish(msg.topic, "", retain=True)
                    results["cleaned"] += 1
                elif session is None and force:
                    logger.debug("Removing untrackable msg", topic=msg.topic)
                    cleaner.publish(msg.topic, "", retain=True)
                    results["cleaned"] += 1
                else:
                    logger.debug(
                        "Retaining topic with current session: %s",
                        msg.topic,
                    )
            else:
                logger.debug("Skipping clean of %s", msg.topic)

        cleaner.on_message = cleanup
        try:
            # NOTE(review): noLocal is an MQTT v5 option; this client uses paho's default
            # protocol, where paho appears to ignore subscribe options.
            options = paho.mqtt.subscribeoptions.SubscribeOptions(noLocal=True)
            cleaner.subscribe(f"{self.hass_cfg.discovery.prefix}/update/#", options=options)
            cleaner.subscribe(f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}/#", options=options)

            while time.time() - results["last_timestamp"] <= wait_time:
                rc = cleaner.loop(0.5)
                if rc != MQTTErrorCode.MQTT_ERR_SUCCESS:
                    # A refused/lost connection makes loop() return immediately; stop
                    # instead of spinning until wait_time expires.
                    logger.warning("Clean loop stopped early", result_code=rc)
                    break
                # Yield so other coroutines on this event loop are not starved.
                await asyncio.sleep(0)
        finally:
            # Always release the temporary connection (it was previously leaked).
            cleaner.disconnect()

        logger.info(
            f"Clean completed, discovered:{results['discovered']}, handled:{results['handled']}, cleaned:{results['cleaned']}"
        )

    def safe_json_decode(self, jsonish: str | bytes | None) -> dict:
        """Best-effort decode of a JSON-object payload.

        If a plain decode fails, retries with the first and last character stripped
        (presumably for payloads wrapped in stray delimiters). Returns ``{}`` for
        empty/None input, undecodable input, or JSON that is not an object.
        """
        if not jsonish:
            # None or an empty payload (e.g. a cleared retained message): nothing to decode.
            return {}
        try:
            decoded = json.loads(jsonish)
        except (ValueError, TypeError):
            self.log.exception("JSON decode fail (%s)", jsonish)
            try:
                decoded = json.loads(jsonish[1:-1])
            except (ValueError, TypeError):
                self.log.exception("JSON decode fail (%s)", jsonish[1:-1])
                return {}
        if not isinstance(decoded, dict):
            # Honor the dict contract; callers use .get() on the result.
            self.log.debug("JSON payload is not an object", payload_type=type(decoded).__name__)
            return {}
        return decoded

    async def execute_command(
        self, msg: MQTTMessage | LocalMessage, on_update_start: Callable, on_update_end: Callable
    ) -> None:
        """Run an install command from ``msg`` against the provider registered for its topic.

        The payload is expected to be ``"<source_type>|<component>|<command>"``, and only
        ``install`` is accepted. After a successful update, the resolved discovery is
        republished according to its publish policy. Exceptions are logged, not raised.
        """
        logger = self.log.bind(topic=msg.topic, payload=msg.payload)
        comp_name: str | None = None
        command: str | None = None
        try:
            logger.info("Execution starting")
            source_type: str | None = None

            payload: str | None = None
            if isinstance(msg.payload, bytes):
                payload = msg.payload.decode("utf-8")
            elif isinstance(msg.payload, str):
                payload = msg.payload
            if payload and "|" in payload:
                # Raises ValueError unless there are exactly three fields; logged by the handler below.
                source_type, comp_name, command = payload.split("|")

            provider: ReleaseProvider | None = self.providers_by_topic.get(msg.topic) if msg.topic else None
            if not provider:
                logger.warning("Unexpected provider type %s", msg.topic)
            elif provider.source_type != source_type:
                logger.warning("Unexpected source type %s", source_type)
            elif command != "install" or not comp_name:
                logger.warning("Invalid payload in command message: %s", msg.payload)
            else:
                logger.info(
                    "Passing %s command to %s scanner for %s",
                    command,
                    source_type,
                    comp_name,
                )
                updated = provider.command(comp_name, command, on_update_start, on_update_end)
                discovery = provider.resolve(comp_name)
                if updated and discovery:
                    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT and self.hass_cfg.discovery.enabled:
                        self.publish_hass_config(discovery)
                    if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                        self.publish_discovery(discovery)
                    if discovery.publish_policy == PublishPolicy.HOMEASSISTANT:
                        self.publish_hass_state(discovery)
                else:
                    logger.debug("No change to republish after execution")
            logger.info("Execution ended")
        except Exception:
            logger.exception("Execution failed", component=comp_name, command=command)

    def local_message(self, discovery: Discovery, command: str) -> None:
        """Simulate an incoming MQTT message for local commands."""
        msg = LocalMessage(
            topic=self.command_topic(discovery.provider), payload="|".join([discovery.source_type, discovery.name, command])
        )
        self.handle_message(msg)

    def on_subscribe(
        self,
        _client: mqtt.Client,
        userdata: Any,
        mid: int,
        reason_code_list: list[ReasonCode],
        properties: Properties | None = None,
    ) -> None:
        """paho VERSION2 subscribe-acknowledgement callback; logs only."""
        self.log.debug(
            "on_subscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", userdata, mid, reason_code_list, properties
        )

    def on_unsubscribe(
        self,
        _client: mqtt.Client,
        userdata: Any,
        mid: int,
        reason_code_list: list[ReasonCode],
        properties: Properties | None = None,
    ) -> None:
        """paho VERSION2 unsubscribe-acknowledgement callback; logs only."""
        self.log.debug(
            "on_unsubscribe, userdata=%s, mid=%s, reasons=%s, properties=%s", userdata, mid, reason_code_list, properties
        )

    def on_message(self, _client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
        """Callback for incoming MQTT messages."""  # noqa: D401
        if msg.topic in self.providers_by_topic:
            self.handle_message(msg)
        else:
            # apparently the root non-wildcard sub sometimes brings in child topics
            self.log.debug("Unhandled message #%s on %s:%s", msg.mid, msg.topic, msg.payload)

    def handle_message(self, msg: mqtt.MQTTMessage | LocalMessage) -> None:
        """Schedule ``execute_command`` for ``msg`` on the publisher's event loop.

        Safe to call from paho's network thread: the coroutine is handed over with
        ``asyncio.run_coroutine_threadsafe``.
        """

        # NOTE(review): update_start/update_end accept the MQTT policy, but publish_hass_state()
        # returns early for anything other than HOMEASSISTANT, so MQTT-only discoveries get no
        # in-progress publish here. Confirm whether publish_discovery(in_progress=...) was intended.
        def update_start(discovery: Discovery) -> None:
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_hass_state(discovery, in_progress=True)

        def update_end(discovery: Discovery) -> None:
            if discovery.publish_policy in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
                self.publish_hass_state(discovery, in_progress=False)

        if self.event_loop is not None:
            asyncio.run_coroutine_threadsafe(self.execute_command(msg, update_start, update_end), self.event_loop)
        else:
            self.log.error("No event loop to handle message", topic=msg.topic)

    def config_topic(self, discovery: Discovery) -> str:
        """Home Assistant discovery config topic for ``discovery``."""
        prefix = self.hass_cfg.discovery.prefix
        return f"{prefix}/update/{self.node_cfg.name}_{discovery.source_type}_{discovery.name}/update/config"

    def state_topic(self, discovery: Discovery) -> str:
        """State topic under this node's topic root for ``discovery``."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{discovery.source_type}/{discovery.name}/state"

    def general_topic(self, discovery: Discovery) -> str:
        """Topic carrying the full (non Home Assistant specific) discovery payload."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{discovery.source_type}/{discovery.name}"

    def command_topic(self, provider: ReleaseProvider) -> str:
        """Topic on which install commands for ``provider`` are received."""
        return f"{self.cfg.topic_root}/{self.node_cfg.name}/{provider.source_type}"

    def publish_discovery(self, discovery: Discovery, in_progress: bool = False) -> None:
        """Comprehensive, non Home Assistant specific, base publication."""
        if discovery.publish_policy not in (PublishPolicy.HOMEASSISTANT, PublishPolicy.MQTT):
            return
        self.log.debug("Discovery publish: %s", discovery)
        payload: dict[str, Any] = discovery.as_dict()
        # Assumes as_dict() always contains an "update" mapping -- confirm in the model.
        payload["update"]["in_progress"] = in_progress  # ty:ignore[invalid-assignment]
        self.publish(self.general_topic(discovery), payload)

    def publish_hass_state(self, discovery: Discovery, in_progress: bool = False) -> None:
        """Publish Home Assistant state for ``discovery`` (HOMEASSISTANT policy only)."""
        if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
            return
        self.log.debug("HASS State update, in progress: %s, discovery: %s", in_progress, discovery)
        self.publish(
            self.state_topic(discovery),
            hass_format_state(
                discovery,
                discovery.session,
                in_progress=in_progress,
            ),
        )

    def publish_hass_config(self, discovery: Discovery) -> None:
        """Publish the Home Assistant discovery config for ``discovery`` (HOMEASSISTANT policy only)."""
        if discovery.publish_policy != PublishPolicy.HOMEASSISTANT:
            return
        object_id = f"{discovery.source_type}_{self.node_cfg.name}_{discovery.name}"
        self.log.debug("HASS Config: %s", object_id)

        self.publish(
            self.config_topic(discovery),
            hass_format_config(
                discovery=discovery,
                object_id=object_id,
                area=self.hass_cfg.area,
                state_topic=self.state_topic(discovery),
                attrs_topic=self.general_topic(discovery) if self.hass_cfg.extra_attributes else None,
                command_topic=self.command_topic(discovery.provider),
                force_command_topic=self.hass_cfg.force_command_topic,
                device_creation=self.hass_cfg.device_creation,
            ),
        )

    def subscribe_hass_command(self, provider: ReleaseProvider) -> str:
        """Register ``provider`` for its command topic and subscribe to it.

        Skips the registration when the topic is already registered or there is no
        client yet. Returns the command topic either way.
        """
        topic = self.command_topic(provider)
        if topic in self.providers_by_topic or self.client is None:
            self.log.debug("Skipping subscription", topic=topic)
        else:
            self.log.info("Handler subscribing", topic=topic)
            self.providers_by_topic[topic] = provider
            self.client.subscribe(topic)
        return topic

    def loop_once(self) -> None:
        """Run a single iteration of the paho network loop, if a client exists."""
        if self.client:
            self.client.loop()

    def publish(self, topic: str, payload: dict, qos: int = 0, retain: bool = True) -> None:
        """JSON-encode ``payload`` and publish it (retained by default); no-op without a client."""
        if self.client:
            self.client.publish(topic, payload=json.dumps(payload), qos=qos, retain=retain)