Coverage for src/nats_contrib/request_many/executor.py: 97%

53 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-25 02:03 +0100

1from __future__ import annotations 

2 

3import asyncio 

4 

5from nats.aio.client import Client 

6from nats.aio.msg import Msg 

7from nats.errors import BadSubscriptionError 

8 

9 

class RequestManyExecutor:
    """Publish a NATS request and collect many responses on one reply inbox.

    Instances are callable; see ``__call__`` for the available stop
    conditions.
    """

    def __init__(
        self,
        nc: Client,
        max_wait: float | None = None,
    ) -> None:
        """Create an executor bound to a NATS client.

        Args:
            nc: The NATS client used to create inboxes, subscribe and publish.
            max_wait: Default maximum time (seconds) to wait for responses,
                used when a call passes neither ``max_wait`` nor
                ``max_interval``. Any falsy value (None, 0) falls back to 0.5.
        """
        self.nc = nc
        self.max_wait = max_wait or 0.5

    async def __call__(
        self,
        subject: str,
        reply_inbox: str | None = None,
        payload: bytes | None = None,
        headers: dict[str, str] | None = None,
        max_wait: float | None = None,
        max_count: int | None = None,
        max_interval: float | None = None,
        stop_on_sentinel: bool = False,
    ) -> list[Msg]:
        """Request many responses from the same subject.

        This function does not raise an error when no responses are received.

        Responses are received until one of the following conditions is met:

        - max_wait seconds have passed.
        - max_count responses have been received.
        - max_interval seconds have passed between responses.
        - A sentinel message (empty payload) is received and stop_on_sentinel
          is True. The sentinel itself is not included in the result.
        - A message with a "Status" header of "503" is received. That message
          is not included in the result.

        Messages delivered after a stop condition has been met are ignored.

        Args:
            subject: The subject to send the request to.
            reply_inbox: The inbox to receive the responses in. A new inbox is
                created if None.
            payload: The payload to send with the request (empty if None).
            headers: The headers to send with the request.
            max_wait: The maximum amount of time to wait for responses. When
                both max_wait and max_interval are None, the instance-level
                default max wait is used.
            max_count: The maximum number of responses to accept. No limit by
                default. Also passed to the subscription as ``max_msgs``.
            max_interval: The maximum amount of time between responses. No
                limit by default.
            stop_on_sentinel: Whether to stop when a sentinel message is
                received. False by default.

        Returns:
            The accepted response messages, in arrival order.
        """
        if max_wait is None and max_interval is None:
            max_wait = self.max_wait
        # Create an inbox for the responses if one wasn't provided.
        if reply_inbox is None:
            reply_inbox = self.nc.new_inbox()
        responses: list[Msg] = []
        # get_running_loop() is the supported way to reach the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        # Set as soon as any stop condition is met. The callback checks it
        # so nothing is appended past the stop point, even when messages
        # arrive before the unsubscribe below has completed.
        done = asyncio.Event()
        # Time of the most recent message (or of the request start).
        last_received = loop.time()

        async def callback(msg: Msg) -> None:
            nonlocal last_received
            # Drop anything delivered after we decided to stop.
            if done.is_set():
                return
            last_received = loop.time()
            # A 503 status (no responders): stop without recording it.
            if msg.headers and msg.headers.get("Status") == "503":
                done.set()
                return
            # An empty payload is the sentinel: stop without recording it.
            if stop_on_sentinel and msg.data == b"":
                done.set()
                return
            responses.append(msg)
            if len(responses) == max_count:
                done.set()

        async def check_interval(interval: float) -> None:
            # Sleep exactly until the deadline implied by the latest message,
            # so a silent gap is detected as soon as it reaches `interval`
            # instead of up to one extra interval later.
            while True:
                remaining = last_received + interval - loop.time()
                if remaining <= 0:
                    done.set()
                    return
                await asyncio.sleep(remaining)

        # Subscribe before publishing so no response can be missed.
        sub = await self.nc.subscribe(  # pyright: ignore[reportUnknownMemberType]
            reply_inbox,
            cb=callback,
            max_msgs=max_count or 0,
        )
        tasks: list[asyncio.Task[object]] = []
        try:
            tasks.append(asyncio.create_task(done.wait()))
            if max_wait:
                tasks.append(asyncio.create_task(asyncio.sleep(max_wait)))
            if max_interval:
                tasks.append(asyncio.create_task(check_interval(max_interval)))
            # The subscription is ready and all watchers are running.
            await self.nc.publish(
                subject, payload or b"", reply=reply_inbox, headers=headers
            )
            # Return as soon as any stop condition fires.
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Covers timeouts and errors too: late messages are now ignored.
            done.set()
            for task in tasks:
                if not task.done():
                    task.cancel()
            # Unsubscribe before awaiting anything else, so a cancellation
            # during cleanup cannot leak the subscription.
            try:
                await sub.unsubscribe()
            except BadSubscriptionError:
                # Auto-unsubscribe (max_msgs) may already have removed it.
                pass
            # Let the cancelled helper tasks finish so none stay pending.
            await asyncio.gather(*tasks, return_exceptions=True)

        return responses