Coverage for src/nats_contrib/request_many/utils.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-25 01:51 +0100

1from __future__ import annotations 

2 

3from typing import Any, AsyncContextManager, AsyncIterator, Callable, Generic, TypeVar 

4 

5T = TypeVar("T") 

6R = TypeVar("R") 

7 

8 

def transform(
    source: AsyncContextManager[AsyncIterator[T]],
    map: Callable[[T], R],
) -> AsyncContextManager[AsyncIterator[R]]:
    """Wrap ``source`` so that every value it produces is passed through ``map``.

    The returned object is an async context manager. Entering it enters
    ``source``, and iterating it yields ``map(value)`` for each ``value``
    produced by the async iterator that ``source`` provided.

    A typical use is post-processing the values produced by the
    ``request_many_iter`` method.
    """
    wrapped = TransformAsyncIterator(source=source, map=map)
    return wrapped

21 

22 

class TransformAsyncIterator(Generic[T, R]):
    """Async context manager and async iterator that maps values from a source.

    Entering the context enters ``source`` and keeps the async iterator it
    returns. Each ``__anext__`` call awaits the next value from that iterator
    and returns it after passing it through the ``map`` callable.
    """

    def __init__(
        self,
        source: AsyncContextManager[AsyncIterator[T]],
        map: Callable[[T], R],
    ) -> None:
        # Context manager that produces the underlying async iterator.
        self.factory = source
        # Underlying iterator; only set while the context is active.
        self.iterator: AsyncIterator[T] | None = None
        # Function applied to every value produced by the underlying iterator.
        self.transform = map

    def __aiter__(self) -> TransformAsyncIterator[T, R]:
        return self

    async def __anext__(self) -> R:
        """Return the next source value after applying the map function.

        Raises:
            RuntimeError: if called outside of an ``async with`` block.
            StopAsyncIteration: propagated when the underlying iterator
                is exhausted.
        """
        # Compare against None explicitly: an iterator object may define
        # __bool__/__len__ and be falsy while still being perfectly usable.
        if self.iterator is None:
            raise RuntimeError(
                "TransformAsyncIterator must be used as an async context manager"
            )
        next_value = await self.iterator.__anext__()
        return self.transform(next_value)

    async def __aenter__(self) -> TransformAsyncIterator[T, R]:
        """Enter the source context manager and return this iterator."""
        self.iterator = await self.factory.__aenter__()
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> bool | None:
        """Exit the source context manager.

        The source's return value is propagated unchanged. A source that
        suppresses an exception (returns a truthy value) therefore keeps
        suppressing it when wrapped.
        """
        try:
            return await self.factory.__aexit__(*args, **kwargs)
        finally:
            # The source iterator must not be used once its context has exited;
            # clearing it makes later __anext__ calls raise RuntimeError.
            self.iterator = None