Coverage for src/nats_contrib/request_many/utils.py: 100%
23 statements
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-25 01:51 +0100
« prev ^ index » next coverage.py v7.4.2, created at 2024-02-25 01:51 +0100
1from __future__ import annotations
3from typing import Any, AsyncContextManager, AsyncIterator, Callable, Generic, TypeVar
# Item type yielded by the wrapped (source) async iterator.
T = TypeVar("T")
# Item type produced once the mapping function has been applied.
R = TypeVar("R")
def transform(
    source: AsyncContextManager[AsyncIterator[T]],
    map: Callable[[T], R],
) -> AsyncContextManager[AsyncIterator[R]]:
    """Map every item of an async-iterator context manager through *map*.

    The returned object is an async context manager. Entering it enters
    *source*, and the value bound by ``async with ... as it`` is an async
    iterator that yields ``map(item)`` for each ``item`` produced by the
    iterator *source* returned. Leaving the ``async with`` block leaves
    *source* as well.

    A typical use is reshaping the values returned by the
    `request_many_iter` method.
    """
    wrapper = TransformAsyncIterator(source=source, map=map)
    return wrapper
class TransformAsyncIterator(Generic[T, R]):
    """Async context manager and async iterator that maps a source iterator.

    Entering an instance enters the wrapped *source* context manager and
    keeps the async iterator that it returns. Iterating the instance then
    yields ``map(item)`` for every ``item`` produced by that iterator.
    Exiting the instance exits *source*. The exception details are forwarded,
    and the source's return value is propagated, so a source that suppresses
    an exception still does so through this wrapper.
    """

    def __init__(
        self,
        source: AsyncContextManager[AsyncIterator[T]],
        map: Callable[[T], R],
    ) -> None:
        """Store the source context manager and the mapping function.

        Args:
            source: Async context manager whose ``__aenter__`` returns the
                async iterator to transform. It is not entered here.
            map: Function applied to each item produced by that iterator.
        """
        self.factory = source
        # Set by __aenter__; None means the instance has not been entered.
        self.iterator: AsyncIterator[T] | None = None
        self.transform = map

    def __aiter__(self) -> TransformAsyncIterator[T, R]:
        """Return ``self``; the instance is its own async iterator."""
        return self

    async def __anext__(self) -> R:
        """Return the next source item passed through the mapping function.

        Raises:
            RuntimeError: If the instance has not been entered with
                ``async with``.

        ``StopAsyncIteration`` from the source iterator, and any exception
        raised by the mapping function, propagate unchanged.
        """
        # Compare against None explicitly. An async iterator may legitimately
        # be falsy (e.g. it defines __len__ returning 0), and a truthiness
        # test would wrongly reject it.
        if self.iterator is None:
            raise RuntimeError(
                "TransformAsyncIterator must be used as an async context manager"
            )
        next_value = await self.iterator.__anext__()
        return self.transform(next_value)

    async def __aenter__(self) -> AsyncIterator[R]:
        """Enter the source context manager and return ``self`` for iteration."""
        self.iterator = await self.factory.__aenter__()
        return self

    async def __aexit__(self, *args: Any, **kwargs: Any) -> bool | None:
        """Exit the source context manager with the same exception details.

        Returns:
            Whatever the source's ``__aexit__`` returned. Propagating it lets
            ``async with`` honour a truthy result, which means "suppress the
            exception", exactly as it would for the source itself.
        """
        return await self.factory.__aexit__(*args, **kwargs)