Coverage for src/nats_contrib/micro/middleware.py: 38%
74 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-06 11:09 +0100
1from __future__ import annotations
3from typing import Awaitable, Callable
5from typing_extensions import TypeAlias
7from .request import Handler, Request
NextHandler: TypeAlias = Callable[
    [Request],
    Awaitable["Response"],
]
"""Callable that continues a middleware chain.

It takes the request to pass further down the chain and resolves to the
``Response`` produced by the remaining middlewares and the handler.
"""

Middleware: TypeAlias = Callable[
    [Request, NextHandler],
    Awaitable["Response"],
]
"""Async function wrapping a handler.

It receives the incoming request plus the ``NextHandler`` for the rest of the
chain, and resolves to the ``Response`` that should be sent back.
"""
class Response:
    """Response holds the response data and headers as well as the original request.

    In order to update the response data or headers, use the methods provided by this class.

    Attributes:
        origin: The request this response answers.
    """

    __slots__ = ("origin", "_data", "_headers")

    def __init__(
        self,
        origin: Request,
        data: bytes,
        headers: dict[str, str] | None = None,
    ):
        """Create a response for *origin*.

        Args:
            origin: The request being answered.
            data: The response payload.
            headers: Initial response headers, or ``None`` for no headers.
                The mapping is copied, so header edits made through this
                response never leak back into the caller's dictionary.
        """
        self.origin = origin
        self._data = data
        # Copy: callers may hand in a dict they keep using (e.g. a shared
        # default-headers constant); the mutators below must not touch it.
        self._headers = dict(headers) if headers else {}

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(origin={self.origin!r}, "
            f"data={self._data!r}, headers={self._headers!r})"
        )

    def data(self) -> bytes:
        """Get the response data."""
        return self._data

    def headers(self) -> dict[str, str]:
        """Get the response headers.

        This is the response's own mapping, not a copy: mutating it
        mutates the response.
        """
        return self._headers

    def add_header(self, key: str, value: str) -> None:
        """Add a header to the response, replacing any existing value for *key*."""
        self._headers[key] = value

    def remove_header(self, key: str) -> None:
        """Remove a header from the response; a missing *key* is ignored."""
        self._headers.pop(key, None)

    def update_headers(self, headers: dict[str, str]) -> None:
        """Merge *headers* into the response headers (new values win)."""
        self._headers.update(headers)

    def clear_headers(self) -> None:
        """Remove all response headers."""
        self._headers.clear()

    def set_data(self, data: bytes) -> None:
        """Replace the response data."""
        self._data = data

    def clear_data(self) -> None:
        """Reset the response data to empty bytes."""
        self._data = b""
def apply_middlewares(handler: Handler, middlewares: list[Middleware]) -> Handler:
    """Wrap *handler* so that each middleware in *middlewares* runs around it.

    The first middleware in the list is the outermost layer. When the list
    is empty, *handler* itself is returned untouched.
    """
    if middlewares:
        # Turn the respond-style handler into one that returns a Response,
        # layer the middlewares over it, then convert back to a Handler.
        innermost = _create_next_handler(handler)
        layered = _apply_middlewares_to_next_handler(innermost, middlewares)
        return _create_final_handler(layered)
    return handler
def _create_next_handler(handler: Handler) -> NextHandler:
    """Adapt a respond-style *handler* into a NextHandler that returns a Response."""

    async def capture(request: Request) -> Response:
        # The handler "responds" into this wrapper instead of replying
        # directly, so the reply can be handed back up the middleware chain.
        recorder = _CapturedRequest(request)
        await handler(recorder)
        return recorder.get_response()

    return capture
def _create_final_handler(forward: NextHandler) -> Handler:
    """Turn the outermost NextHandler back into a plain respond-style Handler."""

    async def send(request: Request) -> None:
        result = await forward(request)
        # Reply on the request the response was created for.
        reply = result.origin.respond
        await reply(result.data(), result.headers())

    return send
def _apply_middlewares_to_next_handler(
    handler: NextHandler, middlewares: list[Middleware]
) -> NextHandler:
    """Compose *middlewares* around *handler*.

    The first middleware in the list becomes the outermost layer: it sees
    the request first and the response last. An empty (or otherwise falsy)
    *middlewares* returns *handler* unchanged.
    """
    composed = handler
    if middlewares:
        # Build from the innermost layer outwards.
        for layer in reversed(middlewares):
            composed = _chain_next_handler_and_middleware(composed, layer)
    return composed
def _chain_next_handler_and_middleware(
    handler: NextHandler, middleware: Middleware
) -> NextHandler:
    """Bind *middleware* to *handler*, yielding a new NextHandler."""

    async def call_middleware(request: Request) -> Response:
        # The rest of the chain is handed to the middleware as its "next".
        response = await middleware(request, handler)
        return response

    return call_middleware
class _CapturedRequest(Request):
    """Request wrapper that records the reply instead of sending it.

    ``subject``, ``headers`` and ``data`` are delegated to the wrapped
    request. ``respond`` only stores a ``Response`` whose ``origin`` is the
    wrapped request. The stored response is read back with ``get_response``.
    """

    def __init__(self, request: Request):
        self._request = request
        # Set by respond(); stays None until the handler replies.
        self._response: Response | None = None

    def subject(self) -> str:
        return self._request.subject()

    def headers(self) -> dict[str, str]:
        return self._request.headers()

    def data(self) -> bytes:
        return self._request.data()

    async def respond(self, data: bytes, headers: dict[str, str] | None = None) -> None:
        """Record the reply instead of sending it.

        The headers are copied, so middlewares editing the captured response
        cannot mutate a dict the handler still owns (for instance a shared
        default-headers constant). Calling this again replaces the earlier
        reply.
        """
        copied = dict(headers) if headers else {}
        self._response = Response(self._request, data, copied)

    def get_response(self) -> Response:
        """Return the recorded response.

        Raises:
            ValueError: if the wrapped handler never called ``respond``.
        """
        if self._response is None:
            raise ValueError("No response was set")
        return self._response