Coverage for src/nats_contrib/micro/middleware.py: 38%

74 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-06 11:09 +0100

from __future__ import annotations

from typing import Awaitable, Callable, Iterable

from typing_extensions import TypeAlias

from .request import Handler, Request

8 

NextHandler: TypeAlias = Callable[
    [Request],
    Awaitable["Response"],
]
"""Continuation handed to a middleware: call it with the request to forward
to the rest of the chain; it resolves to the downstream ``Response``."""

Middleware: TypeAlias = Callable[
    [Request, NextHandler],
    Awaitable["Response"],
]
"""A middleware: receives the incoming request plus the ``NextHandler`` for
the rest of the chain, and resolves to a ``Response``."""

14 

15 

class Response:
    """Response produced by a handler, as seen by middlewares.

    Holds the response payload and headers together with the request the
    response answers (``origin``). Middlewares should use the methods of this
    class to inspect or modify the data and headers.
    """

    __slots__ = ("origin", "_data", "_headers")

    def __init__(
        self,
        origin: Request,
        data: bytes,
        headers: dict[str, str] | None = None,
    ) -> None:
        """Create a response.

        Args:
            origin: The request this response answers.
            data: The response payload.
            headers: Response headers, or ``None`` for no headers. The mapping
                is copied, so later edits made through this object never
                mutate the caller's dict.
        """
        self.origin = origin
        self._data = data
        # Copy: otherwise add_header()/update_headers()/... would mutate a dict
        # owned by the handler (e.g. a shared module-level default), leaking
        # middleware-added headers into unrelated responses.
        self._headers: dict[str, str] = dict(headers) if headers else {}

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(origin={self.origin!r}, "
            f"data={self._data!r}, headers={self._headers!r})"
        )

    def data(self) -> bytes:
        """Return the response data."""
        return self._data

    def headers(self) -> dict[str, str]:
        """Return the response headers.

        This is the live header dict of this response: mutating it changes
        the response.
        """
        return self._headers

    def add_header(self, key: str, value: str) -> None:
        """Set header ``key`` to ``value``, replacing any existing value."""
        self._headers[key] = value

    def remove_header(self, key: str) -> None:
        """Remove header ``key``; does nothing if it is not present."""
        self._headers.pop(key, None)

    def update_headers(self, headers: dict[str, str]) -> None:
        """Merge ``headers`` into the response headers (new values win)."""
        self._headers.update(headers)

    def clear_headers(self) -> None:
        """Remove all response headers."""
        self._headers.clear()

    def set_data(self, data: bytes) -> None:
        """Replace the response data."""
        self._data = data

    def clear_data(self) -> None:
        """Reset the response data to empty bytes."""
        self._data = b""

60 

61 

def apply_middlewares(
    handler: Handler, middlewares: Iterable[Middleware]
) -> Handler:
    """Wrap ``handler`` with a chain of middlewares.

    The first middleware in ``middlewares`` is the outermost one: it is called
    first and receives the response last. Each middleware gets the request and
    a ``NextHandler`` that runs the rest of the chain. The ``Response`` that
    comes out of the chain is finally sent with ``response.origin.respond``.

    Args:
        handler: The handler to wrap.
        middlewares: Middlewares to apply, outermost first. Any iterable is
            accepted and consumed once. ``None`` or an empty iterable leaves
            ``handler`` unchanged.

    Returns:
        A new handler running the middleware chain, or ``handler`` itself when
        there is nothing to apply.
    """
    # Materialize once: emptiness of a generator cannot be tested with `not`,
    # and the chain is built with reversed(), which requires a sequence.
    chain = list(middlewares or ())
    if not chain:
        return handler
    forward = _create_next_handler(handler)
    forward = _apply_middlewares_to_next_handler(forward, chain)
    return _create_final_handler(forward)

69 

70 

def _create_next_handler(handler: Handler) -> NextHandler:
    """Adapt a plain ``Handler`` into a ``NextHandler``.

    The returned coroutine function passes the handler a ``_CapturedRequest``
    wrapper. Whatever the handler gives to ``respond()`` is therefore recorded
    instead of sent, and that recorded ``Response`` is returned.
    ``ValueError`` propagates if the handler finished without responding.
    """

    async def _run_and_capture(request: Request) -> Response:
        recorder = _CapturedRequest(request)
        await handler(recorder)
        return recorder.get_response()

    return _run_and_capture

78 

79 

def _create_final_handler(forward: NextHandler) -> Handler:
    """Turn a ``NextHandler`` chain back into a regular ``Handler``.

    The resulting handler runs the chain for the incoming request. It then
    sends the produced response's data and headers by calling ``respond()``
    on that response's ``origin`` request.
    """

    async def _send_result(request: Request) -> None:
        result = await forward(request)
        payload, header_map = result.data(), result.headers()
        await result.origin.respond(payload, header_map)

    return _send_result

86 

87 

def _apply_middlewares_to_next_handler(
    handler: NextHandler, middlewares: list[Middleware]
) -> NextHandler:
    """Wrap a ``NextHandler`` with middlewares, first middleware outermost.

    The list is walked from the end. The last middleware therefore wraps
    ``handler`` directly, and the first one ends up wrapping everything else.
    A falsy ``middlewares`` (empty or ``None``) returns ``handler`` unchanged.
    """
    wrapped = handler
    if middlewares:
        for middleware in reversed(middlewares):
            wrapped = _chain_next_handler_and_middleware(wrapped, middleware)
    return wrapped

98 

99 

def _chain_next_handler_and_middleware(
    handler: NextHandler, middleware: Middleware
) -> NextHandler:
    """Bind ``middleware`` in front of ``handler``.

    The returned ``NextHandler`` calls ``middleware`` with the incoming request
    and ``handler`` as its continuation. It resolves to whatever the
    middleware returns.
    """

    async def _invoke_middleware(request: Request) -> Response:
        response = await middleware(request, handler)
        return response

    return _invoke_middleware

109 

110 

class _CapturedRequest(Request):
    """Request wrapper that records the handler's reply instead of sending it.

    Read accessors delegate to the wrapped request. ``respond()`` only stores
    a ``Response`` whose ``origin`` is the wrapped request, so middlewares can
    inspect or modify it before it is sent.
    """

    def __init__(self, request: Request):
        # NOTE(review): Request.__init__ is not called here (unchanged from
        # the original) — confirm Request defines no per-instance state.
        self._request = request
        self._response: Response | None = None

    def subject(self) -> str:
        """Return the wrapped request's subject."""
        return self._request.subject()

    def headers(self) -> dict[str, str]:
        """Return the wrapped request's headers."""
        return self._request.headers()

    def data(self) -> bytes:
        """Return the wrapped request's data."""
        return self._request.data()

    async def respond(self, data: bytes, headers: dict[str, str] | None = None) -> None:
        """Record the reply. Nothing is sent at this point.

        The headers dict is copied, so middlewares editing the captured
        response cannot mutate a dict owned by the handler. If called more
        than once, only the last reply is kept.
        """
        self._response = Response(self._request, data, dict(headers) if headers else {})

    def get_response(self) -> Response:
        """Return the recorded reply.

        Raises:
            ValueError: if ``respond()`` was never called.
        """
        if self._response is None:
            raise ValueError("No response was set")
        return self._response