Coverage for src/nats_contrib/micro/context.py: 23%
113 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-06 11:09 +0100
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-06 11:09 +0100
1from __future__ import annotations
3import asyncio
4import contextlib
5import datetime
6import inspect
7import signal
8from typing import Any, AsyncContextManager, Awaitable, Callable, Coroutine, TypeVar
10from nats.aio.client import Client as NATS
11from nats_contrib.connect_opts import ConnectOption, connect
13from .api import Service, add_service
# Generic type variables used in annotations throughout this module.
# NOTE(review): `E` is not referenced anywhere in the visible file.
T, E = TypeVar("T"), TypeVar("E")
class Context:
    """A class to run micro services easily.

    This class is useful in a main function to ensure that all async
    resources are cleaned up properly when the program is cancelled.

    It also makes it easy to listen to signals and cancel the program
    when a signal is received.
    """

    def __init__(self, client: NATS | None = None):
        # Everything registered through `enter`/`push` is unwound (LIFO)
        # when the context exits.
        self.exit_stack = contextlib.AsyncExitStack()
        # Set by `cancel()`; awaited by `wait()` and raced in `wait_for()`.
        self.cancel_event = asyncio.Event()
        # Explicit None check rather than relying on the client's truthiness.
        self.client = client if client is not None else NATS()
        self.services: list[Service] = []

    async def connect(self, *options: ConnectOption) -> None:
        """Connect to the NATS server. Does not raise an error when cancelled.

        When the connection completes before the context is cancelled, the
        client is entered into the exit stack so that its async-context exit
        runs when this context exits.
        """
        await self.wait_for(connect(*options, client=self.client))
        if not self.cancelled():
            await self.enter(self.client)

    async def add_service(
        self,
        name: str,
        version: str,
        description: str | None = None,
        metadata: dict[str, str] | None = None,
        queue_group: str | None = None,
        pending_bytes_limit_by_endpoint: int | None = None,
        pending_msgs_limit_by_endpoint: int | None = None,
        now: Callable[[], datetime.datetime] | None = None,
        id_generator: Callable[[], str] | None = None,
        api_prefix: str | None = None,
    ) -> Service:
        """Add a service to the context.

        The service is created with the client of this context (all
        arguments are forwarded positionally to `add_service` from `.api`),
        entered into the exit stack so that it is exited when the context
        exits, and recorded in `self.services`.

        Returns:
            The value returned by entering the service.
        """
        service = add_service(
            self.client,
            name,
            version,
            description,
            metadata,
            queue_group,
            pending_bytes_limit_by_endpoint,
            pending_msgs_limit_by_endpoint,
            now,
            id_generator,
            api_prefix,
        )
        await self.enter(service)
        self.services.append(service)
        return service

    def reset(self) -> None:
        """Call `reset()` on every service added through this context."""
        for service in self.services:
            service.reset()

    def cancel(self) -> None:
        """Set the cancel event."""
        self.cancel_event.set()

    def cancelled(self) -> bool:
        """Check if the context was cancelled."""
        return self.cancel_event.is_set()

    def add_disconnected_callback(
        self, callback: Callable[[], Awaitable[None]]
    ) -> None:
        """Add a disconnected callback to the NATS client.

        The callback is chained with any previously registered one: the new
        callback runs first, then the existing one.
        """
        existing = self.client._disconnected_cb  # pyright: ignore[reportPrivateUsage]
        self.client._disconnected_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_closed_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
        """Add a closed callback to the NATS client (chained after the new one)."""
        existing = self.client._closed_cb  # pyright: ignore[reportPrivateUsage]
        self.client._closed_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_reconnected_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
        """Add a reconnected callback to the NATS client (chained after the new one)."""
        existing = self.client._reconnected_cb  # pyright: ignore[reportPrivateUsage]
        self.client._reconnected_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_error_callback(
        self, callback: Callable[[Exception], Awaitable[None]]
    ) -> None:
        """Add an error callback to the NATS client.

        The callback is chained with any previously registered one: both
        receive the same exception, the new callback first.
        """
        existing = self.client._error_cb  # pyright: ignore[reportPrivateUsage]
        self.client._error_cb = _chain1(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def trap_signal(self, *signals: signal.Signals) -> None:
        """Cancel the context when one of the given signals is received.

        Defaults to SIGINT and SIGTERM when no signal is given.
        """
        if not signals:
            signals = (signal.Signals.SIGINT, signal.Signals.SIGTERM)
        try:
            # Preferred: the loop we are actually running in.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Called outside of a running loop: keep the historical behavior.
            # get_event_loop() is deprecated in this situation on recent Pythons.
            loop = asyncio.get_event_loop()
        for sig in signals:
            loop.add_signal_handler(sig, self.cancel)

    def push(self, callback: Callable[[], Awaitable[None] | None]) -> None:
        """Register a callback to run when the context exits.

        Coroutine functions are awaited. Other callables are called, and if
        they return an awaitable (as the type hint allows, e.g. a lambda
        returning a coroutine), that awaitable is awaited too instead of
        being silently dropped.
        """
        if inspect.iscoroutinefunction(callback):
            self.exit_stack.push_async_callback(callback)
            return

        async def _call_and_maybe_await() -> None:
            result = callback()
            if inspect.isawaitable(result):
                await result

        self.exit_stack.push_async_callback(_call_and_maybe_await)

    async def enter(self, async_context: AsyncContextManager[T]) -> T:
        """Enter an async context manager; it is exited when this context exits.

        Returns the value produced by the manager's `__aenter__`.
        """
        return await self.exit_stack.enter_async_context(async_context)

    async def wait(self) -> None:
        """Wait for the cancel event to be set."""
        await self.cancel_event.wait()

    async def wait_for(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Run a coroutine and cancel it if the context is cancelled first.

        This method does not raise an exception when the coroutine is
        cancelled because of the context. Use `.cancelled()` on the context
        to check whether that happened. Exceptions raised by the coroutine
        itself are propagated.
        """
        await _run_until_first_complete(coro, self.wait())

    async def __aenter__(self) -> "Context":
        await self.exit_stack.__aenter__()
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        # Forward the exception info (if any) to the entered resources so
        # they can observe a failing/cancelled body instead of a clean exit.
        # The stack's "suppress" result is deliberately ignored: this context
        # never swallows exceptions.
        exc_type, exc_value, traceback = (tuple(args) + (None, None, None))[:3]
        try:
            await self.exit_stack.__aexit__(exc_type, exc_value, traceback)
        finally:
            self.services.clear()

    async def run_forever(
        self,
        setup: Callable[[Context], Coroutine[Any, Any, None]],
        /,
        *options: ConnectOption,
        trap_signals: bool | tuple[signal.Signals, ...] = False,
    ) -> None:
        """Useful in a main function of a program.

        This method first connects to the NATS server using the provided
        options, then runs the setup function, and finally waits until the
        context is cancelled. The NATS client and everything entered during
        setup are exited (in reverse order) when the method returns.

        If trap_signals is True, SIGINT and SIGTERM are trapped and cancel
        the context when received. Other signals can be trapped by providing
        a tuple of signals instead.

        This method does not raise an exception if the context is cancelled.
        Use `.cancelled()` on the context to check whether it was cancelled.

        Warning:
            The context must not have been used as an async context manager
            before calling this method.

        Args:
            setup: A coroutine function called with the context to set up
                the program.
            options: The options to pass to the connect function.
            trap_signals: True to trap SIGINT and SIGTERM, or a tuple of
                signals to trap.
        """
        async with self as ctx:
            if trap_signals:
                if trap_signals is True:
                    trap_signals = (signal.Signals.SIGINT, signal.Signals.SIGTERM)
                ctx.trap_signal(*trap_signals)
            # Go through Context.connect so the client is entered into the
            # exit stack and is therefore closed when the context exits.
            await ctx.connect(*options)
            if ctx.cancelled():
                return
            await ctx.wait_for(setup(ctx))
            if ctx.cancelled():
                return
            await ctx.wait()
211def _chain0(
212 existing: Callable[[], Awaitable[None]] | None, new: Callable[[], Awaitable[None]]
213) -> Callable[[], Awaitable[None]]:
214 """Chain two coroutines."""
215 if existing is None:
216 return new
218 async def chained() -> None:
219 try:
220 await new()
221 finally:
222 await existing()
224 return chained
227def _chain1(
228 existing: Callable[[T], Awaitable[None]] | None, new: Callable[[T], Awaitable[None]]
229) -> Callable[[T], Awaitable[None]]:
230 """Chain two coroutines."""
231 if existing is None:
232 return new
234 async def chained(arg: T) -> None:
235 try:
236 await new(arg)
237 finally:
238 await existing(arg)
240 return chained
243async def _run_until_first_complete(
244 *coros: Coroutine[Any, Any, Any],
245) -> None:
246 """Run a bunch of coroutines and stop as soon as the first stops."""
247 tasks: list[asyncio.Task[Any]] = [asyncio.create_task(coro) for coro in coros]
248 try:
249 await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
250 finally:
251 for task in tasks:
252 if not task.done():
253 task.cancel()
254 # Make sure all tasks are cancelled AND finished
255 await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
256 # Check for exceptions
257 for task in tasks:
258 if task.cancelled():
259 continue
260 if err := task.exception():
261 raise err
def run(
    setup: Callable[[Context], Coroutine[Any, Any, None]],
    /,
    *options: ConnectOption,
    trap_signals: bool | tuple[signal.Signals, ...] = False,
    client: NATS | None = None,
) -> None:
    """Helper function to run an async program.

    Creates a `Context` (using `client` if given) and runs
    `Context.run_forever(setup, *options, trap_signals=trap_signals)`
    inside `asyncio.run`.

    Args:
        setup: A coroutine function called with the context to set up
            the program.
        options: The options to pass to the connect function.
        trap_signals: True to trap SIGINT and SIGTERM, or a tuple of
            signals to trap.
        client: An optional NATS client to use instead of a new one.
    """

    async def _main() -> None:
        # Build the Context inside the loop started by asyncio.run() so the
        # asyncio.Event it creates is bound to that loop. On Python < 3.10 an
        # Event created outside a running loop binds to the default loop and
        # awaiting it later fails with "attached to a different loop".
        await Context(client=client).run_forever(
            setup,
            *options,
            trap_signals=trap_signals,
        )

    asyncio.run(_main())
279 )