Coverage for src/nats_contrib/micro/sdk/decorators.py: 48%
121 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-27 05:11 +0100
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-27 05:11 +0100
1from __future__ import annotations
3import datetime
4import inspect
5from dataclasses import dataclass, field
6from typing import Any, AsyncContextManager, Callable, Iterator, TypeVar
8from nats.aio.client import Client as NATS
9from typing_extensions import dataclass_transform
11from ..api import Group, Service, add_service
12from ..request import Handler
13from ..middleware import Middleware
# Type variable for the classes passed through the ``@service`` / ``@group``
# decorators; the decorators hand back the very same class.
S = TypeVar("S", bound=Any)
# Type variable for handlers passed through ``@endpoint``; the decorator
# returns the handler itself, so its precise callable type is preserved.
F = TypeVar("F", bound=Callable[..., Any])
@dataclass
class EndpointSpec:
    """Specification attached to a handler by the ``@endpoint`` decorator.

    Attributes:
        name: Alphanumeric, human-readable string describing the endpoint.
            Several endpoints may share the same name.
        subject: Subject of the endpoint. When unset, it defaults to the
            endpoint name.
        queue_group: Queue group of the endpoint. When unset, it defaults to
            the queue group of the parent group or service.
        metadata: Optional metadata of the endpoint.
        pending_msgs_limit: Pending messages limit for this endpoint.
        pending_bytes_limit: Pending bytes limit for this endpoint.
        disabled: When true, the endpoint is skipped at registration time.
    """

    # Identity and routing.
    name: str
    subject: str | None = None
    queue_group: str | None = None
    metadata: dict[str, str] | None = None
    # Flow-control limits (None means "not set here").
    pending_msgs_limit: int | None = None
    pending_bytes_limit: int | None = None
    # Registration switch.
    disabled: bool = False
@dataclass
class ServiceSpec:
    """Specification attached to a class by the ``@service`` decorator.

    Attributes:
        name: Kind of the service, shared by all services with the same
            name. Expected to contain only A-Z, a-z, 0-9, dash and
            underscore (not validated here).
        version: Version of the service. Expected to be a valid semantic
            version (not validated here).
        description: Optional description of the service.
        metadata: Optional metadata of the service.
        queue_group: Default queue group of the service.
        pending_msgs_limit_by_endpoint: Default pending messages limit of the
            service, applied per endpoint subject.
        pending_bytes_limit_by_endpoint: Default pending bytes limit of the
            service, applied per endpoint subject.
    """

    # Required identity.
    name: str
    version: str
    # Optional descriptive information.
    description: str | None = None
    metadata: dict[str, str] | None = None
    # Defaults inherited by endpoints.
    queue_group: str | None = None
    pending_msgs_limit_by_endpoint: int | None = None
    pending_bytes_limit_by_endpoint: int | None = None
@dataclass
class GroupSpec:
    """Specification attached to a class by the ``@group`` decorator.

    Attributes:
        name: Alphanumeric, human-readable string describing the group.
            Several groups may share the same name.
        queue_group: Queue group of the group. When unset, it defaults to the
            queue group of the parent group or service.
        pending_msgs_limit: Pending messages limit for this group.
        pending_bytes_limit: Pending bytes limit for this group.
    """

    name: str
    queue_group: str | None = None
    # Flow-control limits (None means "not set here").
    pending_msgs_limit: int | None = None
    pending_bytes_limit: int | None = None
@dataclass_transform(field_specifiers=(field,))
def service(
    name: str,
    version: str,
    description: str | None = None,
    metadata: dict[str, str] | None = None,
    queue_group: str | None = None,
    pending_msgs_limit_by_endpoint: int | None = None,
    pending_bytes_limit_by_endpoint: int | None = None,
) -> Callable[[type[S]], type[S]]:
    """Return a class decorator that defines a micro service.

    The decorated class is converted into a dataclass. A ``ServiceSpec``
    built from the arguments is stored on it as ``__service_spec__``, which
    ``get_service_spec`` reads back.

    Args:
        name: Kind of the service.
        version: Version of the service.
        description: Optional description of the service.
        metadata: Optional metadata of the service (stored by reference).
        queue_group: Default queue group of the service.
        pending_msgs_limit_by_endpoint: Default pending messages limit.
        pending_bytes_limit_by_endpoint: Default pending bytes limit.

    Returns:
        A decorator that returns the same class object it receives.
    """

    def decorate(target: type[S]) -> type[S]:
        # A fresh spec per decorated class, so specs are never shared.
        service_spec = ServiceSpec(
            name=name,
            version=version,
            description=description,
            metadata=metadata,
            queue_group=queue_group,
            pending_msgs_limit_by_endpoint=pending_msgs_limit_by_endpoint,
            pending_bytes_limit_by_endpoint=pending_bytes_limit_by_endpoint,
        )
        # Without ``slots=True``, ``dataclass`` mutates and returns the same class.
        decorated = dataclass(target)
        setattr(decorated, "__service_spec__", service_spec)
        return decorated

    return decorate
@dataclass_transform(field_specifiers=(field,))
def group(
    name: str,
    queue_group: str | None = None,
    pending_msgs_limit_by_endpoint: int | None = None,
    pending_bytes_limit_by_endpoint: int | None = None,
) -> Callable[[type[S]], type[S]]:
    """Return a class decorator that defines a micro service group.

    The decorated class is converted into a dataclass. A ``GroupSpec`` built
    from the arguments is stored on it as ``__group_spec__``, which
    ``get_group_spec`` reads back.

    Args:
        name: Name of the group.
        queue_group: Queue group of the group.
        pending_msgs_limit_by_endpoint: Stored as ``GroupSpec.pending_msgs_limit``.
        pending_bytes_limit_by_endpoint: Stored as ``GroupSpec.pending_bytes_limit``.

    Returns:
        A decorator that returns the same class object it receives.
    """

    def decorate(target: type[S]) -> type[S]:
        # A fresh spec per decorated class, so specs are never shared.
        group_spec = GroupSpec(
            name=name,
            queue_group=queue_group,
            pending_msgs_limit=pending_msgs_limit_by_endpoint,
            pending_bytes_limit=pending_bytes_limit_by_endpoint,
        )
        # Without ``slots=True``, ``dataclass`` mutates and returns the same class.
        decorated = dataclass(target)
        setattr(decorated, "__group_spec__", group_spec)
        return decorated

    return decorate
def endpoint(
    name: str | None = None,
    subject: str | None = None,
    queue_group: str | None = None,
    pending_msgs_limit: int | None = None,
    pending_bytes_limit: int | None = None,
    disabled: bool = False,
) -> Callable[[F], F]:
    """Return a decorator that marks a function as an endpoint handler.

    The decorator attaches an ``EndpointSpec`` to the function as
    ``__endpoint_spec__``. It returns the function itself, not a wrapper.
    ``get_endpoints_specs`` discovers handlers through that attribute.

    Args:
        name: Endpoint name. A falsy value (``None`` or ``""``) falls back to
            the decorated function's ``__name__``.
        subject: Subject of the endpoint.
        queue_group: Queue group of the endpoint.
        pending_msgs_limit: Pending messages limit for the endpoint.
        pending_bytes_limit: Pending bytes limit for the endpoint.
        disabled: When true, registration skips this endpoint.
    """

    def mark(handler: F) -> F:
        endpoint_name = name if name else handler.__name__
        setattr(
            handler,
            "__endpoint_spec__",
            EndpointSpec(
                name=endpoint_name,
                subject=subject,
                queue_group=queue_group,
                # This decorator does not expose endpoint metadata.
                metadata=None,
                pending_msgs_limit=pending_msgs_limit,
                pending_bytes_limit=pending_bytes_limit,
                disabled=disabled,
            ),
        )
        return handler

    return mark
def register_service(
    client: NATS,
    service: Any,
    prefix: str | None = None,
    now: Callable[[], datetime.datetime] | None = None,
    id_generator: Callable[[], str] | None = None,
    api_prefix: str | None = None,
    middlewares: list[Middleware] | None = None,
) -> AsyncContextManager[Service]:
    """Return an async context manager that runs *service* as a micro service.

    On enter:

    - The spec attached by ``@service`` is read.
    - A micro service is created with ``add_service`` and started.
    - Every non-disabled endpoint found on *service* is added, under a group
      named *prefix* when one is given.

    If configuring the started service fails, it is stopped before the error
    propagates. On exit, the started service is stopped.

    Args:
        client: NATS client used to create the service.
        service: Object whose class was decorated with ``@service``.
        prefix: Optional group name under which all endpoints are added.
        now: Optional clock forwarded to ``add_service``.
        id_generator: Optional ID factory forwarded to ``add_service``.
        api_prefix: Optional API prefix forwarded to ``add_service``.
        middlewares: Middlewares forwarded to every ``add_endpoint`` call.

    Raises:
        TypeError: On enter, if *service* was not decorated with ``@service``.
    """

    class ServiceMounter:
        def __init__(self) -> None:
            # Holds the running service between a successful enter and exit.
            self.service: Service | None = None

        async def __aenter__(self) -> Service:
            service_spec = get_service_spec(service)
            # NOTE(review): positional order mirrors the original call (bytes
            # limit before msgs limit); confirm against add_service's signature.
            micro_service = add_service(
                client,
                service_spec.name,
                service_spec.version,
                service_spec.description,
                service_spec.metadata,
                service_spec.queue_group,
                service_spec.pending_bytes_limit_by_endpoint,
                service_spec.pending_msgs_limit_by_endpoint,
                now=now,
                id_generator=id_generator,
                api_prefix=api_prefix,
            )
            await micro_service.start()
            try:
                parent: Group | Service = (
                    micro_service.add_group(prefix) if prefix else micro_service
                )
                for endpoint_handler, endpoint_spec in get_endpoints_specs(service):
                    if endpoint_spec.disabled:
                        continue
                    await parent.add_endpoint(
                        name=endpoint_spec.name,
                        handler=endpoint_handler,
                        subject=endpoint_spec.subject,
                        queue_group=endpoint_spec.queue_group,
                        pending_msgs_limit=endpoint_spec.pending_msgs_limit,
                        pending_bytes_limit=endpoint_spec.pending_bytes_limit,
                        middlewares=middlewares,
                    )
            except BaseException:
                # ``async with`` never calls ``__aexit__`` when ``__aenter__``
                # raises. Stop the already-started service here so it does
                # not leak, then propagate the original error.
                await micro_service.stop()
                raise
            self.service = micro_service
            return micro_service

        async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
            # Drop the reference first so a repeated exit cannot stop twice.
            started, self.service = self.service, None
            if started is not None:
                await started.stop()

    return ServiceMounter()
async def register_group(
    service: Service,
    group: Any,
    prefix: str | None = None,
    middlewares: list[Middleware] | None = None,
) -> None:
    """Add the endpoints of a ``@group``-decorated object to *service*.

    A group is created from the object's ``GroupSpec``, nested under a group
    named *prefix* when one is given. Every non-disabled endpoint found on
    *group* is then added to it, with *middlewares* forwarded to each
    ``add_endpoint`` call.

    Raises:
        TypeError: If *group* was not decorated with ``@group``.
    """
    spec = get_group_spec(group)
    # Optionally nest the new group under an extra prefix group.
    container: Group | Service = service.add_group(prefix) if prefix else service
    target = container.add_group(
        name=spec.name,
        queue_group=spec.queue_group,
        pending_msgs_limit_by_endpoint=spec.pending_msgs_limit,
        pending_bytes_limit_by_endpoint=spec.pending_bytes_limit,
    )
    for handler, endpoint_spec in get_endpoints_specs(group):
        if not endpoint_spec.disabled:
            await target.add_endpoint(
                name=endpoint_spec.name,
                handler=handler,
                subject=endpoint_spec.subject,
                queue_group=endpoint_spec.queue_group,
                pending_msgs_limit=endpoint_spec.pending_msgs_limit,
                pending_bytes_limit=endpoint_spec.pending_bytes_limit,
                middlewares=middlewares,
            )
def get_service_spec(instance: object) -> ServiceSpec:
    """Return the spec stored as ``__service_spec__`` by ``@service``.

    Accepts the decorated class itself or any instance of it, since the
    decorator stores the spec as a class attribute.

    Raises:
        TypeError: If *instance* has no ``__service_spec__`` attribute.
    """
    try:
        return instance.__service_spec__  # type: ignore[attr-defined]
    except AttributeError:
        # The TypeError fully explains the problem. Suppress the internal
        # AttributeError context so it does not clutter the traceback.
        raise TypeError("ServiceRouter must be decorated with @service") from None
def get_group_spec(instance: object) -> GroupSpec:
    """Return the spec stored as ``__group_spec__`` by ``@group``.

    Accepts the decorated class itself or any instance of it, since the
    decorator stores the spec as a class attribute.

    Raises:
        TypeError: If *instance* has no ``__group_spec__`` attribute.
    """
    try:
        return instance.__group_spec__  # type: ignore[attr-defined]
    except AttributeError:
        # The TypeError fully explains the problem. Suppress the internal
        # AttributeError context so it does not clutter the traceback.
        raise TypeError("Group must be decorated with @group") from None
def get_endpoints_specs(instance: object) -> Iterator[tuple[Handler, EndpointSpec]]:
    """Lazily yield ``(handler, spec)`` pairs for members marked by ``@endpoint``.

    Members come from ``inspect.getmembers``. They are therefore ordered by
    attribute name and fetched with ``getattr``, so methods of an instance
    are yielded as bound methods. A member counts as an endpoint when it
    has an ``__endpoint_spec__`` attribute.
    """
    yield from (
        (candidate, candidate.__endpoint_spec__)
        for _attr_name, candidate in inspect.getmembers(instance)
        if hasattr(candidate, "__endpoint_spec__")
    )