From f0d84a9ee45480de42afbbbaefc02524988d6b69 Mon Sep 17 00:00:00 2001 From: Adrian Chaves Date: Thu, 26 Mar 2026 12:08:24 +0100 Subject: [PATCH 1/2] Add trust_env --- CHANGES.rst | 8 +++++++ tests/test_async.py | 18 ++++++++++++++- tests/test_auth.py | 7 +++++- tests/test_main.py | 16 ++++++++++++- tests/test_sync.py | 7 ++++++ tests/test_utils.py | 18 ++++++++++++++- zyte_api/__main__.py | 55 +++++++++++++++++++++++++++++--------------- zyte_api/_async.py | 27 +++++++++++++++------- zyte_api/_sync.py | 6 +++++ 9 files changed, 131 insertions(+), 31 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index e80dbe5..1edc383 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,14 @@ Changes ======= +0.9.0 (unreleased) +------------------ + +* Added an opt-in ``trust_env`` parameter to :class:`~zyte_api.AsyncZyteAPI` + and :class:`~zyte_api.ZyteAPI`, and an opt-in ``--trust-env`` CLI flag, to + allow honoring environment-based network settings (e.g. ``HTTP_PROXY`` and + ``HTTPS_PROXY``). + 0.8.2 (2026-02-10) ------------------ diff --git a/tests/test_async.py b/tests/test_async.py index 50e38b8..f99f70b 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -2,11 +2,12 @@ import asyncio from typing import TYPE_CHECKING, Any -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest from zyte_api import AggressiveRetryFactory, AsyncZyteAPI, RequestError +from zyte_api._utils import create_session from zyte_api.aio.client import AsyncClient from zyte_api.apikey import NoApiKey from zyte_api.errors import ParsedError @@ -54,6 +55,21 @@ def test_api_key(client_cls): client_cls() +@pytest.mark.asyncio +async def test_session_inherits_client_trust_env(mockserver): + client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"), trust_env=True) + async with client.session() as session: + assert session._session._trust_env is True + + +@pytest.mark.asyncio +async def test_get_creates_session_with_client_trust_env(mockserver): + client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"), trust_env=True) + with patch("zyte_api._async.create_session", wraps=create_session) as create_session_mock: + await client.get({"url": "https://a.example"}) + assert create_session_mock.call_args.kwargs["trust_env"] is True + + @pytest.mark.parametrize( ("client_cls", "get_method"), ( diff --git a/tests/test_auth.py b/tests/test_auth.py index 86838ff..f2c7da6 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -14,6 +14,11 @@ def run_zyte_api(args, env, mockserver): + base_env = { + key: value + for key, value in environ.items() + if key not in {"ZYTE_API_KEY", "ZYTE_API_ETH_KEY"} + } with NamedTemporaryFile("w") as url_list: url_list.write("https://a.example\n") url_list.flush() @@ -29,7 +34,7 @@ def run_zyte_api(args, env, mockserver): ], capture_output=True, check=False, - env={**environ, **env}, + env={**base_env, **env}, ) diff --git a/tests/test_main.py b/tests/test_main.py index 6c2c71e..99f635d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -11,7 +11,7 @@ import pytest from zyte_api import RequestError -from zyte_api.__main__ import run +from zyte_api.__main__ import _get_argument_parser, run if TYPE_CHECKING: from collections.abc import Iterable @@ -108,6 +108,7 @@ async def test_run(queries, expected_response, store_errors, exception): api_url = "https://example.com" api_key = "fake_key" retry_errors = True + trust_env = True # Create a mock for AsyncZyteAPI async_client_mock = Mock() @@ -138,8 +139,15 @@ async def test_run(queries, expected_response, store_errors, exception): api_key=api_key, retry_errors=retry_errors, store_errors=store_errors, + trust_env=trust_env, ) + assert async_client_mock.call_args.kwargs["trust_env"] is True + create_session_mock.assert_called_once_with( + connection_pool_size=n_conn, + trust_env=True, + ) + assert get_json_content(temporary_file) == expected_response tmp_path.unlink() @@ -218,6 +226,12 @@ def test_empty_input(mockserver): assert result.stderr == b"No input queries found. Is the input file empty?\n" +def test_trust_env_flag_parsing() -> None: + parser = _get_argument_parser() + args = parser.parse_args(["--trust-env", "--api-key", "a", "README.rst"]) + assert args.trust_env is True + + def test_intype_txt_implicit(mockserver): result = _run(input_="https://a.example", mockserver=mockserver) assert not result.returncode diff --git a/tests/test_sync.py b/tests/test_sync.py index 4e3ca6b..0c5e561 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -3,6 +3,7 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock +from unittest.mock import patch import pytest @@ -19,6 +20,12 @@ def test_api_key(): ZyteAPI() +def test_trust_env_is_forwarded(): + with patch("zyte_api._sync.AsyncZyteAPI") as async_client: + ZyteAPI(api_key="a", trust_env=True) + assert async_client.call_args.kwargs["trust_env"] is True + + def test_get(mockserver): client = ZyteAPI(api_key="a", api_url=mockserver.urljoin("/")) expected_result = { diff --git a/tests/test_utils.py b/tests/test_utils.py index b0d89eb..d0c677f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -12,6 +12,21 @@ async def test_create_session_custom_connector(): custom_connector = TCPConnector(limit=1850) session = create_session(connector=custom_connector) assert session.connector == custom_connector + await session.close() + + +@pytest.mark.asyncio +async def test_create_session_trust_env_disabled_by_default(): + session = create_session() + assert session._trust_env is False + await session.close() + + +@pytest.mark.asyncio +async def test_create_session_trust_env_can_be_enabled(): + session = create_session(trust_env=True) + assert session._trust_env is True + await session.close() @pytest.mark.parametrize( @@ -121,4 +136,5 @@ async def test_deprecated_create_session(): DeprecationWarning, match=r"^zyte_api\.aio\.client\.create_session is deprecated", ): - _create_session() + session = _create_session() + await session.close() diff --git a/zyte_api/__main__.py b/zyte_api/__main__.py index 382b48a..f43a28d 100644 --- a/zyte_api/__main__.py +++ b/zyte_api/__main__.py @@ -4,8 +4,10 @@ import argparse import asyncio +from contextlib import nullcontext import json import logging +from pathlib import Path import random import sys from typing import IO, Any, Literal @@ -42,6 +44,7 @@ async def run( retry_errors: bool = True, store_errors: bool | None = None, eth_key: str | None = None, + trust_env: bool = False, ) -> None: if stop_on_errors is not _UNSET: warn( @@ -65,9 +68,13 @@ def write_output(content: Any) -> None: elif eth_key: auth_kwargs["eth_key"] = eth_key client = AsyncZyteAPI( - n_conn=n_conn, api_url=api_url, retrying=retrying, **auth_kwargs + n_conn=n_conn, + api_url=api_url, + retrying=retrying, + trust_env=trust_env, + **auth_kwargs, ) - async with create_session(connection_pool_size=n_conn) as session: + async with create_session(connection_pool_size=n_conn, trust_env=trust_env) as session: result_iter = client.iter( queries=queries, session=session, @@ -128,7 +135,6 @@ def _get_argument_parser(program_name: str = "zyte-api") -> argparse.ArgumentPar ) p.add_argument( "INPUT", - type=argparse.FileType("r", encoding="utf8"), help=( "Path to an input file (see 'Command-line client > Input file' in " "the docs for details)." @@ -151,8 +157,7 @@ def _get_argument_parser(program_name: str = "zyte-api") -> argparse.ArgumentPar p.add_argument( "--output", "-o", - default=sys.stdout, - type=argparse.FileType("w", encoding="utf8"), + default=None, help=( "Path for the output file. Results are written into the output " "file in JSON Lines format.\n" @@ -225,6 +230,14 @@ def _get_argument_parser(program_name: str = "zyte-api") -> argparse.ArgumentPar ), action="store_true", ) + p.add_argument( + "--trust-env", + help=( + "Enable environment-based network settings such as HTTP_PROXY and " + "HTTPS_PROXY for Zyte API requests." + ), + action="store_true", + ) return p @@ -234,7 +247,8 @@ def _main(program_name: str = "zyte-api") -> None: args = p.parse_args() logging.basicConfig(stream=sys.stderr, level=getattr(logging, args.loglevel)) - queries = read_input(args.INPUT, args.intype) + with Path(args.INPUT).open(encoding="utf8") as input_fp: + queries = read_input(input_fp, args.intype) if not queries: print("No input queries found. Is the input file empty?", file=sys.stderr) sys.exit(-1) @@ -245,23 +259,26 @@ def _main(program_name: str = "zyte-api") -> None: queries = queries[: args.limit] logger.info( - f"Loaded {len(queries)} urls from {args.INPUT.name}; shuffled: {args.shuffle}" + f"Loaded {len(queries)} urls from {args.INPUT}; shuffled: {args.shuffle}" ) logger.info(f"Running Zyte API (connections: {args.n_conn})") - loop = asyncio.get_event_loop() - coro = run( - queries, - out=args.output, - n_conn=args.n_conn, - api_url=args.api_url, - api_key=args.api_key, - eth_key=args.eth_key, - retry_errors=not args.dont_retry_errors, - store_errors=args.store_errors, + run_kwargs = { + "n_conn": args.n_conn, + "api_url": args.api_url, + "api_key": args.api_key, + "eth_key": args.eth_key, + "retry_errors": not args.dont_retry_errors, + "store_errors": args.store_errors, + "trust_env": args.trust_env, + } + output_context = ( + nullcontext(sys.stdout) + if args.output is None + else Path(args.output).open("w", encoding="utf8") ) - loop.run_until_complete(coro) - loop.close() + with output_context as out: + asyncio.run(run(queries, out=out, **run_kwargs)) if __name__ == "__main__": diff --git a/zyte_api/_async.py b/zyte_api/_async.py index 9f00f61..39588d4 100644 --- a/zyte_api/_async.py +++ b/zyte_api/_async.py @@ -45,6 +45,7 @@ def _post_func( class _AsyncSession: def __init__(self, client: AsyncZyteAPI, **session_kwargs: Any): self._client: AsyncZyteAPI = client + session_kwargs.setdefault("trust_env", client.trust_env) self._session: aiohttp.ClientSession = create_session( client.n_conn, **session_kwargs ) @@ -123,6 +124,7 @@ def __init__( retrying: AsyncRetrying | None = None, user_agent: str | None = None, eth_key: str | None = None, + trust_env: bool = False, ): if retrying is not None and not isinstance(retrying, AsyncRetrying): raise ValueError( @@ -134,6 +136,7 @@ def __init__( self.agg_stats = AggStats() self.retrying = retrying or zyte_api_retrying self.user_agent = user_agent or USER_AGENT + self.trust_env = trust_env self._semaphore = asyncio.Semaphore(n_conn) self._auth: str | _x402Handler self.auth: AuthInfo @@ -190,6 +193,10 @@ async def get( ) -> dict[str, Any]: """Asynchronous equivalent to :meth:`ZyteAPI.get`.""" retrying = retrying or self.retrying + owned_session: aiohttp.ClientSession | None = None + if session is None: + owned_session = create_session(self.n_conn, trust_env=self.trust_env) + session = owned_session post = _post_func(session) url = self.api_url + endpoint @@ -257,14 +264,18 @@ async def request() -> dict[str, Any]: request = retrying.wraps(request) try: - # Try to make a request - result = await request() - self.agg_stats.n_success += 1 - except Exception: - self.agg_stats.n_fatal_errors += 1 - raise - - return result + try: + # Try to make a request + result = await request() + self.agg_stats.n_success += 1 + except Exception: + self.agg_stats.n_fatal_errors += 1 + raise + + return result + finally: + if owned_session is not None: + await owned_session.close() def iter( self, diff --git a/zyte_api/_sync.py b/zyte_api/_sync.py index 95afb8d..21d2cc0 100644 --- a/zyte_api/_sync.py +++ b/zyte_api/_sync.py @@ -104,6 +104,10 @@ class ZyteAPI: *user_agent* is the user agent string reported to Zyte API. Defaults to ``python-zyte-api/``. + *trust_env* controls whether :mod:`aiohttp` honors environment-based + network settings (e.g. ``HTTP_PROXY`` and ``HTTPS_PROXY``). Defaults to + ``False``. + .. tip:: To change the ``User-Agent`` header sent to a target website, use :http:`request:customHttpRequestHeaders` instead. """ @@ -117,6 +121,7 @@ def __init__( retrying: AsyncRetrying | None = None, user_agent: str | None = None, eth_key: str | None = None, + trust_env: bool = False, ): self._async_client = AsyncZyteAPI( api_key=api_key, @@ -125,6 +130,7 @@ def __init__( retrying=retrying, user_agent=user_agent, eth_key=eth_key, + trust_env=trust_env, ) def get( From bad7988be67726f5164ee960568706ce17f471a2 Mon Sep 17 00:00:00 2001 From: Adrian Chaves Date: Thu, 26 Mar 2026 13:00:45 +0100 Subject: [PATCH 2/2] Apply feedback and address deprecations --- tests/test_async.py | 4 +++- tests/test_sync.py | 3 +-- zyte_api/__main__.py | 35 +++++++++++++++++++++++------------ 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/tests/test_async.py b/tests/test_async.py index f99f70b..42a5f3a 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -65,7 +65,9 @@ async def test_session_inherits_client_trust_env(mockserver): @pytest.mark.asyncio async def test_get_creates_session_with_client_trust_env(mockserver): client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"), trust_env=True) - with patch("zyte_api._async.create_session", wraps=create_session) as create_session_mock: + with patch( + "zyte_api._async.create_session", wraps=create_session + ) as create_session_mock: await client.get({"url": "https://a.example"}) assert create_session_mock.call_args.kwargs["trust_env"] is True diff --git a/tests/test_sync.py b/tests/test_sync.py index 0c5e561..79c060b 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -2,8 +2,7 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any -from unittest.mock import AsyncMock -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest diff --git a/zyte_api/__main__.py b/zyte_api/__main__.py index f43a28d..ab055c4 100644 --- a/zyte_api/__main__.py +++ b/zyte_api/__main__.py @@ -4,12 +4,12 @@ import argparse import asyncio -from contextlib import nullcontext import json import logging -from pathlib import Path import random import sys +from contextlib import nullcontext +from pathlib import Path from typing import IO, Any, Literal from warnings import warn @@ -74,7 +74,9 @@ def write_output(content: Any) -> None: trust_env=trust_env, **auth_kwargs, ) - async with create_session(connection_pool_size=n_conn, trust_env=trust_env) as session: + async with create_session( + connection_pool_size=n_conn, trust_env=trust_env + ) as session: result_iter = client.iter( queries=queries, session=session, @@ -247,8 +249,15 @@ def _main(program_name: str = "zyte-api") -> None: args = p.parse_args() logging.basicConfig(stream=sys.stderr, level=getattr(logging, args.loglevel)) - with Path(args.INPUT).open(encoding="utf8") as input_fp: - queries = read_input(input_fp, args.intype) + if args.INPUT == "-": + with nullcontext(sys.stdin) as input_fp: + queries = read_input(input_fp, args.intype) + else: + try: + with Path(args.INPUT).open(encoding="utf8") as input_fp: + queries = read_input(input_fp, args.intype) + except OSError as e: + p.error(f"Cannot open input file {args.INPUT!r}: {e}") if not queries: print("No input queries found. Is the input file empty?", file=sys.stderr) sys.exit(-1) @@ -272,13 +281,15 @@ def _main(program_name: str = "zyte-api") -> None: "store_errors": args.store_errors, "trust_env": args.trust_env, } - output_context = ( - nullcontext(sys.stdout) - if args.output is None - else Path(args.output).open("w", encoding="utf8") - ) - with output_context as out: - asyncio.run(run(queries, out=out, **run_kwargs)) + if args.output is None or args.output == "-": + with nullcontext(sys.stdout) as out: + asyncio.run(run(queries, out=out, **run_kwargs)) + else: + try: + with Path(args.output).open("w", encoding="utf8") as out: + asyncio.run(run(queries, out=out, **run_kwargs)) + except OSError as e: + p.error(f"Cannot open output file {args.output!r}: {e}") if __name__ == "__main__":