Coverage for src/nats_contrib/request_many/iterator.py: 93%

90 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-25 01:54 +0100

1from __future__ import annotations 

2 

3import asyncio 

4from typing import Any 

5 

6from nats.aio.client import Client 

7from nats.aio.msg import Msg 

8from nats.aio.subscription import Subscription 

9from nats.errors import BadSubscriptionError 

10 

11 

class RequestManyIterator:
    """Async context manager and async iterator collecting replies to one request.

    Entering the context subscribes to ``inbox`` and publishes the request;
    exiting it unsubscribes and cancels the internal timer tasks.
    """

    def __init__(
        self,
        nc: Client,
        subject: str,
        inbox: str,
        payload: bytes | None = None,
        headers: dict[str, str] | None = None,
        max_wait: float | None = None,
        max_interval: float | None = None,
        max_count: int | None = None,
        stop_on_sentinel: bool = False,
    ) -> None:
        """Request many responses from the same subject.

        Request is sent when entering the async context manager and
        unsubscribed when exiting.

        The async iterator yielded by the context manager does not raise an
        error when no responses are received.

        Responses are received until one of the following conditions is met:

        - max_wait seconds have passed.
        - max_count responses have been received.
        - max_interval seconds have passed between responses.
        - A sentinel message (empty payload) is received and stop_on_sentinel
          is True.

        When any of the conditions is met, the async iterator raises
        StopAsyncIteration on the next call to __anext__, and the
        subscription is unsubscribed on exit.

        Args:
            nc: The NATS client used to subscribe and to publish the request.
            subject: The subject to send the request to.
            inbox: The inbox subject to receive the responses on.
            payload: The payload to send with the request. An empty payload
                is sent when None.
            headers: The headers to send with the request.
            max_wait: The maximum amount of time to wait for responses.
                Defaults to 0.5 seconds when neither max_wait nor
                max_interval is given.
            max_interval: The maximum amount of time between responses.
                No limit by default.
            max_count: The maximum number of responses to accept.
                No limit by default.
            stop_on_sentinel: Whether to stop when a sentinel message is
                received. False by default.
        """
        if max_wait is None and max_interval is None:
            max_wait = 0.5
        # Save all the arguments as instance variables.
        self.nc = nc
        self.subject = subject
        self.payload = payload
        self.headers = headers
        self.inbox = inbox
        self.max_wait = max_wait
        self.max_count = max_count
        self.max_interval = max_interval
        self.stop_on_sentinel = stop_on_sentinel
        # Initialize the state of the request many iterator
        self._sub: Subscription | None = None
        self._did_unsubscribe = False
        self._total_received = 0
        # Loop timestamp of the last activity (request sent or response
        # received). The real value is set in __aenter__, where a running
        # event loop is guaranteed; no loop is required to build the object.
        self._last_received = 0.0
        self._tasks: list[asyncio.Task[object]] = []
        self._pending_task: asyncio.Task[Msg] | None = None

    def __aiter__(self) -> RequestManyIterator:
        """RequestManyIterator is an asynchronous iterator."""
        return self

    async def __anext__(self) -> Msg:
        """Return the next message or raise StopAsyncIteration.

        Raises:
            RuntimeError: If the iterator is used before entering the
                async context manager.
            StopAsyncIteration: When one of the stop conditions is met.
        """
        if not self._sub:
            raise RuntimeError(
                "RequestManyIterator must be used as an async context manager"
            )
        # Exit early if we've already unsubscribed
        if self._did_unsubscribe:
            raise StopAsyncIteration
        # Exit early if we received all the messages
        if self.max_count and self._total_received == self.max_count:
            await self.cleanup()
            raise StopAsyncIteration
        # Create a task to wait for the next message
        task: asyncio.Task[Msg] = asyncio.create_task(_fetch(self._sub))
        self._pending_task = task
        # Wait for the next message or any of the timer tasks to complete
        await asyncio.wait(
            [task, *self._tasks],
            return_when=asyncio.FIRST_COMPLETED,
        )
        # A timer fired first (or the fetch was cancelled): stop iterating
        if task.cancelled() or not task.done():
            await self.cleanup()
            raise StopAsyncIteration
        # This will raise an exception if an error occurred within the task
        msg = task.result()
        # A response arrived: restart the max_interval countdown. Without
        # this the interval timer would always expire after its first sleep.
        self._last_received = asyncio.get_running_loop().time()
        # If the message is a 503 status, stop without counting it
        if msg.headers and msg.headers.get("Status") == "503":
            await self.cleanup()
            raise StopAsyncIteration
        # Always increment the total received count
        self._total_received += 1
        # A sentinel message is counted but not returned to the caller
        if self.stop_on_sentinel and msg.data == b"":
            await self.cleanup()
            raise StopAsyncIteration
        # Return the message
        return msg

    async def __aenter__(self) -> RequestManyIterator:
        """Start the subscription and publish the request."""
        loop = asyncio.get_running_loop()
        # Start the subscription
        sub = await self.nc.subscribe(  # pyright: ignore[reportUnknownMemberType]
            self.inbox,
            max_msgs=self.max_count or 0,
        )
        self._sub = sub
        # The interval countdown starts when the request is issued
        self._last_received = loop.time()
        # Add a task to wait for the max_wait time if needed
        if self.max_wait:
            self._tasks.append(asyncio.create_task(asyncio.sleep(self.max_wait)))
        # Add a task to check the interval if needed
        if self.max_interval:
            interval = self.max_interval

            async def check_interval() -> None:
                while True:
                    # Sleep only for what is left of the current interval so
                    # expiry is detected on time after a mid-interval response.
                    remaining = self._last_received + interval - loop.time()
                    if remaining <= 0:
                        await self.cleanup()
                        return
                    await asyncio.sleep(remaining)

            self._tasks.append(asyncio.create_task(check_interval()))

        # Publish the request
        try:
            await self.nc.publish(
                self.subject,
                self.payload or b"",
                reply=self.inbox,
                headers=self.headers,
            )
        except BaseException:
            # __aexit__ is not called when __aenter__ fails, so release the
            # subscription and timer tasks here before propagating.
            await self.cleanup()
            raise
        # At this point the subscription is ready and all tasks are submitted
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
        """Unsubscribe from the inbox and cancel all the tasks."""
        await self.cleanup()

    async def cleanup(self) -> None:
        """Unsubscribe from the inbox and cancel all the tasks.

        Only the first call does any work; later calls return immediately.
        """
        if self._did_unsubscribe:
            return
        self._did_unsubscribe = True
        # cleanup() may be running inside the interval timer task itself;
        # cancelling that task would interrupt the unsubscribe below.
        current = asyncio.current_task()
        for task in self._tasks:
            if task is not current and not task.done():
                task.cancel()
        if self._pending_task and not self._pending_task.done():
            self._pending_task.cancel()
        if self._sub:
            await _unsubscribe(self._sub)

166 

167 

async def _unsubscribe(sub: Subscription) -> None:
    """Unsubscribe ``sub``, ignoring ``BadSubscriptionError``.

    The error is swallowed because auto-unsubscribe may already have been
    applied to the subscription.
    """
    try:
        await sub.unsubscribe()
    except BadSubscriptionError:
        # Auto-unsubscribe may already have run; nothing left to undo.
        return

174 

175 

async def _fetch(sub: Subscription) -> Msg:
    """Wait for the next message queued on ``sub`` and return it.

    The queue item is marked as done and the size of the message payload is
    subtracted from the subscription's pending-size counter.
    """
    queue = sub._pending_queue  # pyright: ignore[reportPrivateUsage]
    received = await queue.get()
    queue.task_done()
    sub._pending_size -= len(received.data)  # pyright: ignore[reportPrivateUsage]
    return received

180 return msg