Skip to content
2 changes: 1 addition & 1 deletion asyncstdlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .asynctools import borrow, scoped_iter, await_each, any_iter, apply, sync
from .heapq import merge, nlargest, nsmallest

__version__ = "3.13.3"
__version__ = "3.14.0"

__all__ = [
"anext",
Expand Down
2 changes: 1 addition & 1 deletion asyncstdlib/asynctools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from asyncio import iscoroutinefunction
from functools import wraps
from inspect import iscoroutinefunction
from typing import (
Union,
AsyncContextManager,
Expand Down
22 changes: 15 additions & 7 deletions asyncstdlib/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,35 @@ async def _zip_inner_strict(

async def map(
function: Union[Callable[..., R], Callable[..., Awaitable[R]]],
*iterable: AnyIterable[Any],
iterable: AnyIterable[Any],
/,
*iterables: AnyIterable[Any],
strict: bool = False,
) -> AsyncIterator[R]:
r"""
An async iterator mapping an (async) function to items from (async) iterables

:raises ValueError: if the ``iterables`` are not equal length and ``strict`` is set

At each step, ``map`` collects the next item from each iterable and calls
``function`` with all items; if ``function`` provides an awaitable,
``function`` with these items; if ``function`` provides an awaitable,
it is ``await``\ ed. The result is the next value of ``map``.
Barring sync/async translation, ``map`` is equivalent to
``(await function(*args) async for args in zip(iterables))``.

It is important that ``func`` receives *one* item from *each* iterable at
every step. For *n* ``iterable``, ``func`` must take *n* positional arguments.
Similar to :py:func:`~.zip`, ``map`` is exhausted as soon as its
*first* argument is exhausted.
every step. For *n* ``iterables``, ``func`` must take *n* positional arguments.
Similar to :py:func:`~.zip`, ``map`` is exhausted as soon as any of its `iterables`
is exhausted.
When called with ``strict=True``, all ``iterables`` must be of same length;
in this mode ``map`` raises :py:exc:`ValueError` if any ``iterables`` are not
exhausted with the others.

The ``function`` may be a regular or async callable.
Multiple ``iterable`` may be mixed regular and async iterables.
Multiple ``iterables`` may be mixed regular and async iterables.
"""
function = _awaitify(function)
async with ScopedIter(zip(*iterable)) as args_iter:
async with ScopedIter(zip(iterable, *iterables, strict=strict)) as args_iter:
async for args in args_iter:
result = function(*args)
yield await result
Expand Down
12 changes: 12 additions & 0 deletions asyncstdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,30 @@ def map(
function: Callable[[T1], Awaitable[R]],
__it1: AnyIterable[T1],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1], R],
__it1: AnyIterable[T1],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1, T2], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1, T2], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -110,6 +114,7 @@ def map(
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -118,6 +123,7 @@ def map(
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -127,6 +133,7 @@ def map(
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -136,6 +143,7 @@ def map(
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -146,6 +154,7 @@ def map(
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -156,6 +165,7 @@ def map(
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
/,
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -167,6 +177,7 @@ def map(
__it5: AnyIterable[Any],
/,
*iterable: AnyIterable[Any],
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -178,6 +189,7 @@ def map(
__it5: AnyIterable[Any],
/,
*iterable: AnyIterable[Any],
strict: bool = ...,
) -> AsyncIterator[R]: ...
@overload
async def max(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...
Expand Down
3 changes: 2 additions & 1 deletion asyncstdlib/functools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asyncio import iscoroutinefunction
from inspect import iscoroutinefunction
from typing import (
Callable,
Awaitable,
Expand Down Expand Up @@ -281,6 +281,7 @@ def decorator(
async def reduce(
function: Union[Callable[[T, T], T], Callable[[T, T], Awaitable[T]]],
iterable: AnyIterable[T],
/,
initial: T = __REDUCE_SENTINEL, # type: ignore
) -> T:
"""
Expand Down
20 changes: 16 additions & 4 deletions asyncstdlib/functools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,27 @@ def cached_property(
) -> Callable[[Callable[[T], Awaitable[R]]], CachedProperty[T, R]]: ...
@overload
async def reduce(
function: Callable[[T1, T2], Awaitable[T1]], iterable: AnyIterable[T2], initial: T1
function: Callable[[T1, T2], Awaitable[T1]],
iterable: AnyIterable[T2],
/,
initial: T1,
) -> T1: ...
@overload
async def reduce(
function: Callable[[T, T], Awaitable[T]], iterable: AnyIterable[T]
function: Callable[[T, T], Awaitable[T]],
iterable: AnyIterable[T],
/,
) -> T: ...
@overload
async def reduce(
function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1
function: Callable[[T1, T2], T1],
iterable: AnyIterable[T2],
/,
initial: T1,
) -> T1: ...
@overload
async def reduce(function: Callable[[T, T], T], iterable: AnyIterable[T]) -> T: ...
async def reduce(
function: Callable[[T, T], T],
iterable: AnyIterable[T],
/,
) -> T: ...
1 change: 1 addition & 0 deletions asyncstdlib/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AsyncGenerator,
TYPE_CHECKING,
)

if TYPE_CHECKING:
from typing_extensions import TypeAlias

Expand Down
6 changes: 5 additions & 1 deletion docs/source/api/builtins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,13 @@ Iterator transforming

The ``strict`` parameter.

.. autofunction:: map(function: (T, ...) → (await) R, iterable: (async) iter T, ...)
.. autofunction:: map(function: (T, ...) → (await) R, iterable: (async) iter T, ..., /, strict: bool = True)
:async-for: :R

.. versionadded:: 3.14.0

The ``strict`` parameter.

.. autofunction:: enumerate(iterable: (async) iter T, start=0)
:async-for: :(int, T)

Expand Down
55 changes: 41 additions & 14 deletions unittests/test_builtins.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import random
from typing import Any, Callable, Coroutine, TypeVar

import pytest

import asyncstdlib as a

from .utility import sync, asyncify, awaitify

COR = TypeVar("COR", bound=Callable[..., Coroutine[Any, Any, Any]])

def hide_coroutine(corofunc):
def wrapper(*args, **kwargs):

def hide_coroutine(corofunc: COR) -> COR:
"""Make a coroutine function look like a regular function returning a coroutine"""

def wrapper(*args, **kwargs): # type: ignore
return corofunc(*args, **kwargs)

return wrapper
return wrapper # type: ignore


@sync
Expand Down Expand Up @@ -94,7 +99,7 @@ async def __aiter__(self):

@sync
async def test_map_as():
async def map_op(value):
async def map_op(value: int) -> int:
return value * 2

assert [value async for value in a.map(map_op, range(5))] == list(range(0, 10, 2))
Expand All @@ -105,7 +110,7 @@ async def map_op(value):

@sync
async def test_map_sa():
def map_op(value):
async def map_op(value: int) -> int:
return value * 2

assert [value async for value in a.map(map_op, asyncify(range(5)))] == list(
Expand All @@ -118,7 +123,7 @@ def map_op(value):

@sync
async def test_map_aa():
async def map_op(value):
async def map_op(value: int) -> int:
return value * 2

assert [value async for value in a.map(map_op, asyncify(range(5)))] == list(
Expand All @@ -130,6 +135,28 @@ async def map_op(value):
] == list(range(10, 20, 4))


@pytest.mark.parametrize(
"itrs",
[
(range(4), range(5), range(5)),
(range(5), range(4), range(5)),
(range(5), range(5), range(4)),
],
)
@sync
async def test_map_strict_unequal(itrs: "tuple[range, ...]"):
def triple_sum(x: int, y: int, z: int) -> int:
return x + y + z

# no error without strict
async for _ in a.map(triple_sum, *itrs):
pass
# error with strict
with pytest.raises(ValueError):
async for _ in a.map(triple_sum, *itrs, strict=True):
pass


@sync
async def test_max_default():
assert await a.max((), default=3) == 3
Expand All @@ -142,7 +169,7 @@ async def test_max_default():

@sync
async def test_max_sa():
async def minus(x):
async def minus(x: int) -> int:
return -x

assert await a.max(asyncify((1, 2, 3, 4))) == 4
Expand All @@ -167,7 +194,7 @@ async def test_min_default():

@sync
async def test_min_sa():
async def minus(x):
async def minus(x: int) -> int:
return -x

assert await a.min(asyncify((1, 2, 3, 4))) == 1
Expand All @@ -180,7 +207,7 @@ async def minus(x):

@sync
async def test_filter_as():
async def map_op(value):
async def map_op(value: int) -> bool:
return value % 2 == 0

assert [value async for value in a.filter(map_op, range(5))] == list(range(0, 5, 2))
Expand All @@ -194,7 +221,7 @@ async def map_op(value):

@sync
async def test_filter_sa():
def map_op(value):
def map_op(value: int) -> bool:
return value % 2 == 0

assert [value async for value in a.filter(map_op, asyncify(range(5)))] == list(
Expand All @@ -208,7 +235,7 @@ def map_op(value):

@sync
async def test_filter_aa():
async def map_op(value):
async def map_op(value: int) -> bool:
return value % 2 == 0

assert [value async for value in a.filter(map_op, asyncify(range(5)))] == list(
Expand Down Expand Up @@ -286,7 +313,7 @@ async def test_types():
@pytest.mark.parametrize("sortable", sortables)
@pytest.mark.parametrize("reverse", [True, False])
@sync
async def test_sorted_direct(sortable, reverse):
async def test_sorted_direct(sortable: "list[int] | list[float]", reverse: bool):
assert await a.sorted(sortable, reverse=reverse) == sorted(
sortable, reverse=reverse
)
Expand All @@ -305,12 +332,12 @@ async def test_sorted_direct(sortable, reverse):
async def test_sorted_stable():
values = [-i for i in range(20)]

def collision_key(x):
def collision_key(x: int) -> int:
return x // 2

# test the test...
assert sorted(values, key=collision_key) != [
item for key, item in sorted([(collision_key(i), i) for i in values])
item for _, item in sorted([(collision_key(i), i) for i in values])
]
# test the implementation
assert await a.sorted(values, key=awaitify(collision_key)) == sorted(
Expand Down
Loading