Coverage for src/nats_contrib/micro/context.py: 23%

113 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-06 11:09 +0100

1from __future__ import annotations 

2 

3import asyncio 

4import contextlib 

5import datetime 

6import inspect 

7import signal 

8from typing import Any, AsyncContextManager, Awaitable, Callable, Coroutine, TypeVar 

9 

10from nats.aio.client import Client as NATS 

11from nats_contrib.connect_opts import ConnectOption, connect 

12 

13from .api import Service, add_service 

14 

15T = TypeVar("T") 

16E = TypeVar("E") 

17 

18 

class Context:
    """Run micro services with managed cleanup and cancellation.

    A context owns a NATS client, an ``AsyncExitStack`` and a cancel event.
    It is meant to be used from a program's main function to ensure that
    every async resource registered on it is cleaned up properly when the
    program is cancelled.

    It also makes it easy to listen to OS signals and cancel the program
    when one of them is received.
    """

    def __init__(self, client: NATS | None = None):
        """Create a context.

        Args:
            client: The NATS client to use. A new ``NATS()`` client is
                created when omitted.
        """
        self.exit_stack = contextlib.AsyncExitStack()
        self.cancel_event = asyncio.Event()
        self.client = client or NATS()
        # Services started through add_service(), in registration order.
        self.services: list[Service] = []

    async def connect(self, *options: ConnectOption) -> None:
        """Connect to the NATS server. Does not raise an error when cancelled.

        When the connection attempt finishes without the context being
        cancelled, the client is entered on the exit stack so that it is
        cleaned up when the context exits.

        Args:
            options: Connect options forwarded to ``connect``.
        """
        await self.wait_for(connect(*options, client=self.client))
        if not self.cancelled():
            await self.enter(self.client)

    async def add_service(
        self,
        name: str,
        version: str,
        description: str | None = None,
        metadata: dict[str, str] | None = None,
        queue_group: str | None = None,
        pending_bytes_limit_by_endpoint: int | None = None,
        pending_msgs_limit_by_endpoint: int | None = None,
        now: Callable[[], datetime.datetime] | None = None,
        id_generator: Callable[[], str] | None = None,
        api_prefix: str | None = None,
    ) -> Service:
        """Add a service to the context.

        The service is created with the context's client, entered on the
        exit stack (which starts it), and remembered so that ``reset()``
        can reach it.

        Returns:
            The service that was created.
        """
        service = add_service(
            self.client,
            name,
            version,
            description,
            metadata,
            queue_group,
            pending_bytes_limit_by_endpoint,
            pending_msgs_limit_by_endpoint,
            now,
            id_generator,
            api_prefix,
        )
        await self.enter(service)
        self.services.append(service)
        return service

    def reset(self) -> None:
        """Reset all the services added through this context."""
        for service in self.services:
            service.reset()

    def cancel(self) -> None:
        """Set the cancel event."""
        self.cancel_event.set()

    def cancelled(self) -> bool:
        """Check if the context was cancelled."""
        return self.cancel_event.is_set()

    def add_disconnected_callback(
        self, callback: Callable[[], Awaitable[None]]
    ) -> None:
        """Add a disconnected callback to the NATS client.

        Any callback already set on the client is kept and chained.
        """
        existing = self.client._disconnected_cb  # pyright: ignore[reportPrivateUsage]
        self.client._disconnected_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_closed_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
        """Add a closed callback to the NATS client.

        Any callback already set on the client is kept and chained.
        """
        existing = self.client._closed_cb  # pyright: ignore[reportPrivateUsage]
        self.client._closed_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_reconnected_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
        """Add a reconnected callback to the NATS client.

        Any callback already set on the client is kept and chained.
        """
        existing = self.client._reconnected_cb  # pyright: ignore[reportPrivateUsage]
        self.client._reconnected_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_error_callback(
        self, callback: Callable[[Exception], Awaitable[None]]
    ) -> None:
        """Add an error callback to the NATS client.

        Any callback already set on the client is kept and chained.
        """
        existing = self.client._error_cb  # pyright: ignore[reportPrivateUsage]
        self.client._error_cb = _chain1(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def trap_signal(self, *signals: signal.Signals) -> None:
        """Cancel the context when one of the given signals is received.

        Args:
            signals: The signals to trap. Defaults to SIGINT and SIGTERM
                when none are given.
        """
        if not signals:
            signals = (signal.Signals.SIGINT, signal.Signals.SIGTERM)
        loop = asyncio.get_event_loop()
        for sig in signals:
            loop.add_signal_handler(sig, self.cancel)

    def push(self, callback: Callable[[], Awaitable[None] | None]) -> None:
        """Add a callback to the exit stack.

        Coroutine functions are registered as async callbacks, anything
        else as a plain synchronous callback.
        """
        if inspect.iscoroutinefunction(callback):
            self.exit_stack.push_async_callback(callback)
        else:
            self.exit_stack.callback(callback)

    async def enter(self, async_context: AsyncContextManager[T]) -> T:
        """Enter an async context and register its exit on the exit stack."""
        return await self.exit_stack.enter_async_context(async_context)

    async def wait(self) -> None:
        """Wait for the cancel event to be set."""
        await self.cancel_event.wait()

    async def wait_for(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Run a coroutine and cancel it if the context is cancelled.

        This method does not raise an exception if the coroutine is cancelled.
        You can use .cancelled() on the context to check if the coroutine was
        cancelled.
        """
        await _run_until_first_complete(coro, self.wait())

    async def __aenter__(self) -> "Context":
        await self.exit_stack.__aenter__()
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        """Unwind the exit stack and forget the registered services.

        The exception details received from ``async with`` are forwarded to
        the exit stack so that registered context managers can tell a failed
        exit from a clean one. This method always returns ``None``, so an
        exception raised in the body is never suppressed.
        """
        # Tolerate manual calls made without the standard exception triple.
        exc_details = args if len(args) == 3 else (None, None, None)
        try:
            await self.exit_stack.__aexit__(*exc_details)
        finally:
            self.services.clear()

    async def run_forever(
        self,
        setup: Callable[[Context], Coroutine[Any, Any, None]],
        /,
        *options: ConnectOption,
        trap_signals: bool | tuple[signal.Signals, ...] = False,
    ) -> None:
        """Useful in a main function of a program.

        This method will first connect to the NATS server using the provided
        options. It will then run the setup function and finally wait until
        the context is cancelled.

        If trap_signals is True, it will trap SIGINT and SIGTERM signals
        and cancel the context when one of these signals is received.

        Other signals can be trapped by providing a tuple of signals to
        trap.

        This method will not raise an exception if the context is cancelled.

        You can use .cancelled() on the context to check if the coroutine was
        cancelled.

        Warning:
            The context must not have been used as an async context manager
            before calling this method.

        Args:
            setup: A coroutine function receiving the context, used to
                setup the program.
            options: The options to pass to the connect method.
            trap_signals: If True, trap SIGINT and SIGTERM signals. If a
                tuple of signals, trap those signals instead.
        """
        async with self as ctx:
            if trap_signals:
                if trap_signals is True:
                    trap_signals = (signal.Signals.SIGINT, signal.Signals.SIGTERM)
                ctx.trap_signal(*trap_signals)
            # Go through connect() rather than the bare helper so that the
            # client is also entered on the exit stack and cleaned up on exit.
            await ctx.connect(*options)
            if ctx.cancelled():
                return
            await ctx.wait_for(setup(ctx))
            if ctx.cancelled():
                return
            await ctx.wait()

209 

210 

def _chain0(
    existing: Callable[[], Awaitable[None]] | None, new: Callable[[], Awaitable[None]]
) -> Callable[[], Awaitable[None]]:
    """Combine an optional existing zero-argument callback with a new one.

    When there is no existing callback, ``new`` is returned unchanged.
    Otherwise the returned coroutine function awaits ``new`` first and then
    ``existing``; ``existing`` is awaited even if ``new`` raised.
    """
    if existing is not None:
        previous = existing

        async def new_then_previous() -> None:
            try:
                await new()
            finally:
                # The previously registered callback must always run.
                await previous()

        return new_then_previous

    # Nothing was registered before: use the new callback as-is.
    return new

225 

226 

def _chain1(
    existing: Callable[[T], Awaitable[None]] | None, new: Callable[[T], Awaitable[None]]
) -> Callable[[T], Awaitable[None]]:
    """Combine an optional existing one-argument callback with a new one.

    When there is no existing callback, ``new`` is returned unchanged.
    Otherwise the returned coroutine function awaits ``new(arg)`` first and
    then ``existing(arg)``; ``existing`` is awaited even if ``new`` raised.
    """
    if existing is not None:
        previous = existing

        async def new_then_previous(arg: T) -> None:
            try:
                await new(arg)
            finally:
                # The previously registered callback must always run.
                await previous(arg)

        return new_then_previous

    # Nothing was registered before: use the new callback as-is.
    return new

241 

242 

async def _run_until_first_complete(
    *coros: Coroutine[Any, Any, Any],
) -> None:
    """Run a bunch of coroutines and stop as soon as the first stops.

    Every coroutine is scheduled as a task. When the first task finishes,
    the remaining ones are cancelled and awaited until they are done.

    Raises:
        BaseException: The exception of the first task (in argument order)
            that finished with an error. Cancelled tasks are ignored.
    """
    if not coros:
        # asyncio.wait() raises ValueError on an empty set; nothing to run.
        return
    tasks: list[asyncio.Task[Any]] = [asyncio.create_task(coro) for coro in coros]
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            if not task.done():
                task.cancel()
        # Make sure all tasks are cancelled AND finished
        await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
    # Retrieve the outcome of every finished task before raising, so that no
    # failure is left unretrieved on a task once the first error propagates.
    errors = [task.exception() for task in tasks if not task.cancelled()]
    for err in errors:
        # Compare against None: an exception object may itself be falsy.
        if err is not None:
            raise err

262 

263 

def run(
    setup: Callable[[Context], Coroutine[Any, Any, None]],
    /,
    *options: ConnectOption,
    trap_signals: bool | tuple[signal.Signals, ...] = False,
    client: NATS | None = None,
) -> None:
    """Helper function to run an async program.

    Starts a new event loop with ``asyncio.run`` and runs
    ``Context.run_forever`` in it.

    Args:
        setup: A coroutine function receiving the context, used to setup
            the program.
        options: The options to pass to the connect method.
        trap_signals: If True, trap SIGINT and SIGTERM signals. If a tuple
            of signals, trap those signals instead.
        client: The NATS client to use. A new one is created when omitted.
    """

    async def main() -> None:
        # Build the context inside the running loop: on Python < 3.10 an
        # asyncio.Event binds to the current event loop when it is created,
        # which is not the loop that asyncio.run() starts.
        await Context(client=client).run_forever(
            setup,
            *options,
            trap_signals=trap_signals,
        )

    asyncio.run(main())