Coverage for src/nats_contrib/micro/sdk/sdk.py: 24%

116 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-02-27 02:04 +0100

from __future__ import annotations

import asyncio
import contextlib
import datetime
import signal
from typing import Any, AsyncContextManager, Awaitable, Callable, Coroutine, TypeVar

from nats.aio.client import Client as NATS
from nats_contrib.connect_opts import ConnectOption, connect

from ..api import Service, add_service
from .decorators import register_service, register_group

14 

# Generic placeholders for the annotations in this module.
# T parametrises Context.enter and the single-argument callback chain.
T = TypeVar("T")
# NOTE(review): E is not referenced in the visible part of this file;
# kept in case another module imports it.
E = TypeVar("E")

17 

18 

class Context:
    """A class to run micro services easily.

    This class is useful in a main function to ensure that all
    async resources are cleaned up properly when the program is
    cancelled.

    It also makes it easy to listen to signals and cancel the
    program when a signal is received.
    """

    def __init__(self, client: NATS | None = None):
        # Everything entered through `enter()` is unwound by this stack
        # when the context exits.
        self.exit_stack = contextlib.AsyncExitStack()
        # Set by `cancel()`; awaited by `wait()`.
        self.cancel_event = asyncio.Event()
        # Only build a default client when the caller did not supply one.
        self.client = client if client is not None else NATS()
        # Services started through this context, kept so `reset()` can
        # reach them.
        self.services: list[Service] = []

    async def connect(self, *options: ConnectOption) -> None:
        """Connect to the NATS server. Does not raise an error when cancelled.

        When the connection attempt completes without the context being
        cancelled, the client is pushed on the exit stack so that it is
        exited together with the context.
        """
        await self.wait_for(connect(*options, client=self.client))
        if not self.cancelled():
            await self.enter(self.client)

    async def add_service(
        self,
        name: str,
        version: str,
        description: str | None = None,
        metadata: dict[str, str] | None = None,
        queue_group: str | None = None,
        pending_bytes_limit_by_endpoint: int | None = None,
        pending_msgs_limit_by_endpoint: int | None = None,
        now: Callable[[], datetime.datetime] | None = None,
        id_generator: Callable[[], str] | None = None,
        api_prefix: str | None = None,
    ) -> Service:
        """Add a service to the context.

        The service is created with the client held by this context,
        entered on the exit stack and recorded in `self.services`.

        Returns:
            The service that was created and entered.
        """
        service = add_service(
            self.client,
            name,
            version,
            description,
            metadata,
            queue_group,
            pending_bytes_limit_by_endpoint,
            pending_msgs_limit_by_endpoint,
            now,
            id_generator,
            api_prefix,
        )
        await self.enter(service)
        self.services.append(service)
        return service

    async def register_service(
        self,
        service: Any,
        prefix: str | None = None,
        now: Callable[[], datetime.datetime] | None = None,
        id_generator: Callable[[], str] | None = None,
        api_prefix: str | None = None,
    ) -> Service:
        """Register a service in the context.

        The object is handed to the `register_service` decorator helper
        together with the client held by this context; the resulting
        service is entered on the exit stack and recorded in
        `self.services`.

        Returns:
            The service that was created and entered.
        """
        registered = register_service(
            self.client,
            service,
            prefix,
            now,
            id_generator,
            api_prefix,
        )
        await self.enter(registered)
        self.services.append(registered)
        return registered

    async def register_group(
        self, service: Service, group: Any, prefix: str | None = None
    ) -> None:
        """Register a group on an existing service."""
        await register_group(service, group, prefix=prefix)

    def reset(self) -> None:
        """Reset all the services started through this context."""
        for service in self.services:
            service.reset()

    def cancel(self) -> None:
        """Set the cancel event."""
        self.cancel_event.set()

    def cancelled(self) -> bool:
        """Check if the context was cancelled."""
        return self.cancel_event.is_set()

    def add_disconnected_callback(
        self, callback: Callable[[], Awaitable[None]]
    ) -> None:
        """Add a disconnected callback to the NATS client.

        Any callback already installed on the client is preserved and
        chained with the new one.
        """
        existing = self.client._disconnected_cb  # pyright: ignore[reportPrivateUsage]
        self.client._disconnected_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_closed_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
        """Add a closed callback to the NATS client.

        Any callback already installed on the client is preserved and
        chained with the new one.
        """
        existing = self.client._closed_cb  # pyright: ignore[reportPrivateUsage]
        self.client._closed_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_reconnected_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
        """Add a reconnected callback to the NATS client.

        Any callback already installed on the client is preserved and
        chained with the new one.
        """
        existing = self.client._reconnected_cb  # pyright: ignore[reportPrivateUsage]
        self.client._reconnected_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_error_callback(
        self, callback: Callable[[Exception], Awaitable[None]]
    ) -> None:
        """Add an error callback to the NATS client.

        Any callback already installed on the client is preserved and
        chained with the new one.
        """
        existing = self.client._error_cb  # pyright: ignore[reportPrivateUsage]
        self.client._error_cb = _chain1(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def trap_signal(self, *signals: signal.Signals) -> None:
        """Cancel the context when one of the given signals is received.

        Args:
            signals: The signals to trap. Defaults to SIGINT and SIGTERM
                when none are given.
        """
        if not signals:
            signals = (signal.Signals.SIGINT, signal.Signals.SIGTERM)
        loop = asyncio.get_event_loop()
        for sig in signals:
            loop.add_signal_handler(sig, self.cancel)

    async def enter(self, async_context: AsyncContextManager[T]) -> T:
        """Enter an async context and schedule its exit on the exit stack.

        Returns:
            Whatever the context manager yields on entry.
        """
        return await self.exit_stack.enter_async_context(async_context)

    async def wait(self) -> None:
        """Wait for the cancel event to be set."""
        await self.cancel_event.wait()

    async def wait_for(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Run a coroutine in the context and cancel it if the context is cancelled.

        This method does not raise an exception if the coroutine is cancelled.
        You can use .cancelled() on the context to check if the coroutine was
        cancelled.
        """
        # Race the coroutine against the cancel event: whichever finishes
        # first wins and the other one is cancelled.
        await _run_until_first_complete(coro, self.wait())

    async def __aenter__(self) -> "Context":
        await self.exit_stack.__aenter__()
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        try:
            # The stack is always unwound as for a clean exit: the
            # exception details received here are not forwarded to it.
            await self.exit_stack.__aexit__(None, None, None)
        finally:
            # Forget the services even if unwinding the stack failed.
            self.services.clear()

    async def run_forever(
        self,
        setup: Callable[[Context], Coroutine[Any, Any, None]],
        *options: ConnectOption,
        trap_signals: bool | tuple[signal.Signals, ...] = False,
    ) -> None:
        """Useful in a main function of a program.

        This method will first connect to the NATS server using the provided
        options. It will then run the setup function and finally wait until
        the context is cancelled.

        If trap_signals is True, it will trap SIGINT and SIGTERM signals
        and cancel the context when one of these signals is received.

        Other signals can be trapped by providing a tuple of signals to
        trap.

        This method will not raise an exception if the context is cancelled.

        You can use .cancelled() on the context to check if the coroutine was
        cancelled.

        Warning:
            The context must not have been used as an async context manager
            before calling this method.

        Args:
            setup: A coroutine function to setup the program. It receives
                this context as its only argument.
            options: The options to pass to the connect method.
            trap_signals: If True, trap SIGINT and SIGTERM signals. If a
                tuple of signals, trap those signals instead.
        """
        async with self as ctx:
            if trap_signals:
                if trap_signals is True:
                    trap_signals = (signal.Signals.SIGINT, signal.Signals.SIGTERM)
                ctx.trap_signal(*trap_signals)
            # Go through `connect()` rather than calling the connect helper
            # directly so that the client is also registered on the exit
            # stack and gets exited when this method returns.
            await ctx.connect(*options)
            if ctx.cancelled():
                return
            await ctx.wait_for(setup(ctx))
            if ctx.cancelled():
                return
            await ctx.wait()

232 

233 

def _chain0(
    existing: Callable[[], Awaitable[None]] | None, new: Callable[[], Awaitable[None]]
) -> Callable[[], Awaitable[None]]:
    """Combine two zero-argument async callbacks into one.

    Args:
        existing: The callback already installed, or None if there is none.
        new: The callback to add.

    Returns:
        `new` itself when there is no existing callback. Otherwise a
        callback that awaits `new` first and then `existing`.
    """
    if existing is None:
        return new

    async def chained() -> None:
        try:
            await new()
        finally:
            # The previously installed callback still runs when the new
            # one raises; the error propagates afterwards.
            await existing()

    return chained

248 

249 

def _chain1(
    existing: Callable[[T], Awaitable[None]] | None, new: Callable[[T], Awaitable[None]]
) -> Callable[[T], Awaitable[None]]:
    """Combine two single-argument async callbacks into one.

    Args:
        existing: The callback already installed, or None if there is none.
        new: The callback to add.

    Returns:
        `new` itself when there is no existing callback. Otherwise a
        callback that awaits `new` first and then `existing`, passing
        the same argument to both.
    """
    if existing is None:
        return new

    async def chained(arg: T) -> None:
        try:
            await new(arg)
        finally:
            # The previously installed callback still runs when the new
            # one raises; the error propagates afterwards.
            await existing(arg)

    return chained

264 

265 

async def _run_until_first_complete(
    *coros: Coroutine[Any, Any, Any],
) -> None:
    """Run a bunch of coroutines and stop as soon as the first stops.

    Every coroutine is wrapped in a task. Once one of them finishes, the
    remaining tasks are cancelled and awaited so that none is left
    running when this function returns.

    Args:
        coros: The coroutines to race. With no coroutine this returns
            immediately.

    Raises:
        Exception: The first exception, in argument order, raised by a
            task that finished without being cancelled.
    """
    if not coros:
        # asyncio.wait() rejects an empty collection with ValueError;
        # with nothing to run there is nothing to wait for.
        return
    tasks: list[asyncio.Task[Any]] = [asyncio.create_task(coro) for coro in coros]
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            if not task.done():
                task.cancel()
        # Make sure all tasks are cancelled AND finished
        await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
    # Check for exceptions
    for task in tasks:
        if task.cancelled():
            continue
        err = task.exception()
        if err is not None:
            raise err

285 

286 

def run(
    setup: Callable[[Context], Coroutine[Any, Any, None]],
    *options: ConnectOption,
    trap_signals: bool | tuple[signal.Signals, ...] = False,
    client: NATS | None = None,
) -> None:
    """Helper function to run an async program.

    Builds a `Context` around `client` and drives its `run_forever()`
    method to completion with `asyncio.run()`.

    Args:
        setup: A coroutine function receiving the context, used to set
            up the program.
        options: The connect options forwarded to `run_forever()`.
        trap_signals: Forwarded to `run_forever()`.
        client: The NATS client handed to the context, if any.
    """
    context = Context(client=client)
    asyncio.run(
        context.run_forever(
            setup,
            *options,
            trap_signals=trap_signals,
        )
    )