Coverage for src/nats_contrib/micro/sdk/sdk.py: 24%
116 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-27 02:04 +0100
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-27 02:04 +0100
1from __future__ import annotations
3import asyncio
4import contextlib
5import datetime
6import signal
7from typing import Any, AsyncContextManager, Awaitable, Callable, Coroutine, TypeVar
9from nats.aio.client import Client as NATS
10from nats_contrib.connect_opts import ConnectOption, connect
12from ..api import Service, add_service
13from .decorators import register_service, register_group
# Generic type parameters used by the annotations in this module.
T = TypeVar("T")  # e.g. the value produced by an entered async context manager
E = TypeVar("E")  # NOTE(review): not referenced by the code visible here; kept for compatibility
class Context:
    """Run NATS micro services easily from a ``main`` function.

    The context owns a :class:`contextlib.AsyncExitStack`. Every resource
    entered through :meth:`enter` (the NATS client once :meth:`connect`
    succeeds, every service added or registered) is exited in reverse order
    when the context itself exits, so async resources are cleaned up
    properly even when the program is cancelled.

    It also makes it easy to trap signals and cancel the program when one
    of them is received.
    """

    def __init__(self, client: NATS | None = None):
        # Resources entered through `enter()`; exited in LIFO order on exit.
        self.exit_stack = contextlib.AsyncExitStack()
        # Set by `cancel()`, either directly or from a trapped signal.
        self.cancel_event = asyncio.Event()
        # Compare against None so a caller-supplied client is always kept,
        # whatever its truthiness.
        self.client = client if client is not None else NATS()
        # Services started through this context; cleared on exit.
        self.services: list[Service] = []

    async def connect(self, *options: ConnectOption) -> None:
        """Connect to the NATS server. Does not raise an error when cancelled.

        On success (context not cancelled), the client is entered into the
        exit stack so its async-context exit runs when this context exits.
        """
        # Positional options and the `client` keyword bind identically in
        # either order; options-first is simply the conventional spelling.
        await self.wait_for(connect(*options, client=self.client))
        if not self.cancelled():
            await self.enter(self.client)

    async def add_service(
        self,
        name: str,
        version: str,
        description: str | None = None,
        metadata: dict[str, str] | None = None,
        queue_group: str | None = None,
        pending_bytes_limit_by_endpoint: int | None = None,
        pending_msgs_limit_by_endpoint: int | None = None,
        now: Callable[[], datetime.datetime] | None = None,
        id_generator: Callable[[], str] | None = None,
        api_prefix: str | None = None,
    ) -> Service:
        """Add a service to the context.

        The service is created with this context's client, entered into the
        exit stack (this will start the service) and recorded in
        ``self.services``.

        Returns:
            The value returned by entering the service's async context.
        """
        service = add_service(
            self.client,
            name,
            version,
            description,
            metadata,
            queue_group,
            pending_bytes_limit_by_endpoint,
            pending_msgs_limit_by_endpoint,
            now,
            id_generator,
            api_prefix,
        )
        await self.enter(service)
        self.services.append(service)
        return service

    async def register_service(
        self,
        service: Any,
        prefix: str | None = None,
        now: Callable[[], datetime.datetime] | None = None,
        id_generator: Callable[[], str] | None = None,
        api_prefix: str | None = None,
    ) -> Service:
        """Register a decorated service object in the context.

        The service is built by ``register_service`` with this context's
        client, entered into the exit stack (this will start the service)
        and recorded in ``self.services``.
        """
        service = register_service(
            self.client,
            service,
            prefix,
            now,
            id_generator,
            api_prefix,
        )
        await self.enter(service)
        self.services.append(service)
        return service

    async def register_group(
        self, service: Service, group: Any, prefix: str | None = None
    ) -> None:
        """Register a group of endpoints on an existing service."""
        await register_group(service, group, prefix=prefix)

    def reset(self) -> None:
        """Call ``reset()`` on every service started through this context."""
        for service in self.services:
            service.reset()

    def cancel(self) -> None:
        """Set the cancel event, waking up :meth:`wait` and :meth:`wait_for`."""
        self.cancel_event.set()

    def cancelled(self) -> bool:
        """Return True if the context was cancelled."""
        return self.cancel_event.is_set()

    # NOTE(review): the add_*_callback methods below patch private attributes
    # of the NATS client; presumably they must be called after connect(), in
    # case connecting resets these attributes — confirm against nats-py.

    def add_disconnected_callback(
        self, callback: Callable[[], Awaitable[None]]
    ) -> None:
        """Add a disconnected callback, chained with any existing one."""
        existing = self.client._disconnected_cb  # pyright: ignore[reportPrivateUsage]
        self.client._disconnected_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_closed_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
        """Add a closed callback, chained with any existing one."""
        existing = self.client._closed_cb  # pyright: ignore[reportPrivateUsage]
        self.client._closed_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_reconnected_callback(self, callback: Callable[[], Awaitable[None]]) -> None:
        """Add a reconnected callback, chained with any existing one."""
        existing = self.client._reconnected_cb  # pyright: ignore[reportPrivateUsage]
        self.client._reconnected_cb = _chain0(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def add_error_callback(
        self, callback: Callable[[Exception], Awaitable[None]]
    ) -> None:
        """Add an error callback, chained with any existing one."""
        existing = self.client._error_cb  # pyright: ignore[reportPrivateUsage]
        self.client._error_cb = _chain1(  # pyright: ignore[reportPrivateUsage]
            existing, callback
        )

    def trap_signal(self, *signals: signal.Signals) -> None:
        """Cancel the context when one of ``signals`` is received.

        Defaults to SIGINT and SIGTERM when no signal is given. Uses
        ``loop.add_signal_handler``, which asyncio only provides on Unix
        event loops.
        """
        if not signals:
            signals = (signal.Signals.SIGINT, signal.Signals.SIGTERM)
        loop = asyncio.get_event_loop()
        for sig in signals:
            loop.add_signal_handler(sig, self.cancel)

    async def enter(self, async_context: AsyncContextManager[T]) -> T:
        """Enter an async context manager on the exit stack and return its value."""
        return await self.exit_stack.enter_async_context(async_context)

    async def wait(self) -> None:
        """Wait for the cancel event to be set."""
        await self.cancel_event.wait()

    async def wait_for(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Run a coroutine and cancel it if the context is cancelled first.

        This method does not raise an exception if the coroutine is
        cancelled. Use :meth:`cancelled` to check whether the context was
        cancelled. The coroutine's return value is discarded.
        """
        await _run_until_first_complete(coro, self.wait())

    async def __aenter__(self) -> "Context":
        await self.exit_stack.__aenter__()
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> bool:
        """Exit every entered resource in reverse order, then forget services.

        The exception details received from ``async with`` are forwarded to
        the exit stack, so entered context managers observe the propagating
        exception and may suppress it, exactly as with a plain
        :class:`contextlib.AsyncExitStack`.
        """
        # Tolerate being called without the usual three positional arguments.
        exc_type, exc_value, tb = (*args, None, None, None)[:3]
        try:
            return await self.exit_stack.__aexit__(exc_type, exc_value, tb)
        finally:
            self.services.clear()

    async def run_forever(
        self,
        setup: Callable[[Context], Coroutine[Any, Any, None]],
        *options: ConnectOption,
        trap_signals: bool | tuple[signal.Signals, ...] = False,
    ) -> None:
        """Useful in a main function of a program.

        This method first connects to the NATS server using the provided
        options, then runs the setup function, and finally waits until the
        context is cancelled. Everything entered in the context, including
        the NATS client, is exited when this method returns.

        If trap_signals is True, SIGINT and SIGTERM are trapped and cancel
        the context when received. Other signals can be trapped by providing
        a tuple of signals instead.

        This method does not raise an exception if the context is cancelled.
        Use .cancelled() on the context to check whether it was cancelled.

        Warning:
            The context must not have been used as an async context manager
            before calling this method.

        Args:
            setup: A coroutine function used to set up the program.
            options: The options to pass to the connect function.
            trap_signals: True to trap SIGINT and SIGTERM, or a tuple of
                signals to trap.
        """
        async with self as ctx:
            if trap_signals is True:
                ctx.trap_signal()  # defaults to SIGINT and SIGTERM
            elif trap_signals:
                ctx.trap_signal(*trap_signals)
            # ctx.connect() also enters the client on the exit stack, so the
            # connection is released when this method returns.
            await ctx.connect(*options)
            if ctx.cancelled():
                return
            await ctx.wait_for(setup(ctx))
            if ctx.cancelled():
                return
            await ctx.wait()
234def _chain0(
235 existing: Callable[[], Awaitable[None]] | None, new: Callable[[], Awaitable[None]]
236) -> Callable[[], Awaitable[None]]:
237 """Chain two coroutines."""
238 if existing is None:
239 return new
241 async def chained() -> None:
242 try:
243 await new()
244 finally:
245 await existing()
247 return chained
250def _chain1(
251 existing: Callable[[T], Awaitable[None]] | None, new: Callable[[T], Awaitable[None]]
252) -> Callable[[T], Awaitable[None]]:
253 """Chain two coroutines."""
254 if existing is None:
255 return new
257 async def chained(arg: T) -> None:
258 try:
259 await new(arg)
260 finally:
261 await existing(arg)
263 return chained
266async def _run_until_first_complete(
267 *coros: Coroutine[Any, Any, Any],
268) -> None:
269 """Run a bunch of coroutines and stop as soon as the first stops."""
270 tasks: list[asyncio.Task[Any]] = [asyncio.create_task(coro) for coro in coros]
271 try:
272 await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
273 finally:
274 for task in tasks:
275 if not task.done():
276 task.cancel()
277 # Make sure all tasks are cancelled AND finished
278 await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
279 # Check for exceptions
280 for task in tasks:
281 if task.cancelled():
282 continue
283 if err := task.exception():
284 raise err
def run(
    setup: Callable[[Context], Coroutine[Any, Any, None]],
    *options: ConnectOption,
    trap_signals: bool | tuple[signal.Signals, ...] = False,
    client: NATS | None = None,
) -> None:
    """Helper function to run an async program.

    Starts a fresh event loop with :func:`asyncio.run` and runs
    :meth:`Context.run_forever` in it.

    Args:
        setup: A coroutine function used to set up the program.
        options: The options to pass to the connect function.
        trap_signals: True to trap SIGINT and SIGTERM, or a tuple of
            signals to trap.
        client: An optional NATS client; a new one is created otherwise.
    """

    async def _main() -> None:
        # Build the Context inside the running loop. On Python < 3.10,
        # asyncio primitives such as the context's Event bind to the loop
        # current at construction time, which is not the loop asyncio.run()
        # creates.
        ctx = Context(client=client)
        await ctx.run_forever(setup, *options, trap_signals=trap_signals)

    asyncio.run(_main())