diff --git a/kafka/net/manager.py b/kafka/net/manager.py index 6d6c3e5e2..cf2790739 100644 --- a/kafka/net/manager.py +++ b/kafka/net/manager.py @@ -1,5 +1,6 @@ import copy import logging +import inspect import random import socket import ssl @@ -72,12 +73,8 @@ def __init__(self, net, cluster, **configs): def least_used_connections(self): return sorted(filter(lambda conn: conn.connected, self._conns.values()), key=lambda conn: conn.transport.last_activity) - async def _do_bootstrap(self, future, deadline): - while not future.is_done: - if deadline is not None and time.monotonic() >= deadline: - future.failure(Errors.KafkaConnectionError( - 'Unable to bootstrap from %s' % (self.cluster.config['bootstrap_servers'],))) - return + async def _do_bootstrap(self, deadline): + while deadline is None or time.monotonic() < deadline: bootstrap_broker = random.choice(self.cluster.bootstrap_brokers()) try: conn = self.get_connection(bootstrap_broker.node_id, pop_on_close=False, refresh_metadata_on_err=False) @@ -96,8 +93,7 @@ async def _do_bootstrap(self, future, deadline): await conn.init_future except Errors.IncompatibleBrokerVersion: log.error('Did you attempt to connect to a kafka controller (no metadata support)?') - future.failure(conn.init_future.exception) - return + raise except Exception as exc: self._conns.pop(bootstrap_broker.node_id, conn).close(exc) continue @@ -111,19 +107,20 @@ async def _do_bootstrap(self, future, deadline): continue self._conns.pop(bootstrap_broker.node_id, conn).close() log.info('Bootstrap complete: %s', self.cluster) - future.success(True) - return + return True except Exception as e: self.cluster.failed_update(e) continue + else: + raise Errors.KafkaConnectionError( + 'Unable to bootstrap from %s' % (self.cluster.config['bootstrap_servers'],)) def bootstrap(self, timeout_ms=None): if self._bootstrap_future is not None and not self._bootstrap_future.is_done: return self._bootstrap_future - self._bootstrap_future = Future() - self._bootstrap_future.add_errback(lambda exc: log.error('Bootstrap failed: %s', exc)) deadline = None if timeout_ms is None else time.monotonic() + timeout_ms / 1000 - self._net.call_soon(lambda: self._do_bootstrap(self._bootstrap_future, deadline)) + self._bootstrap_future = self.call_soon(self._do_bootstrap, deadline) + self._bootstrap_future.add_errback(lambda exc: log.error('Bootstrap failed: %s', exc)) return self._bootstrap_future @property @@ -298,12 +295,11 @@ def update_metadata(self): if self._metadata_future is not None and not self._metadata_future.is_done: return self._metadata_future self.cluster.request_update() - self._metadata_future = Future() - self._net.call_soon(lambda: self._do_update_metadata(self._metadata_future)) + self._metadata_future = self.call_soon(self._do_update_metadata) return self._metadata_future - async def _do_update_metadata(self, future): - while not future.is_done: + async def _do_update_metadata(self): + while True: node_id = self.least_loaded_node() if node_id is None: if not self.bootstrapped: @@ -320,13 +316,14 @@ async def _do_update_metadata(self, future): log.debug("Sending metadata request %s to node %s", request, node_id) response = await conn.send_request(request) self.cluster.update_metadata(response) - future.success(True) + return True except Exception as exc: self.cluster.failed_update(exc) - future.failure(exc) - # Schedule next periodic refresh - ttl = self.cluster.ttl() / 1000 - self._net.call_later(max(0, ttl), self.update_metadata) + raise + finally: + # Schedule next periodic refresh + ttl = self.cluster.ttl() / 1000 + self._net.call_later(max(0, ttl), self.update_metadata) def close(self, node_id=None): if node_id is not None: @@ -339,3 +336,32 @@ def close(self, node_id=None): def poll(self, timeout_ms=None, future=None): return self._net.poll(timeout_ms=timeout_ms, future=future) + + def call_soon(self, coro, *args): + """Accepts a coroutine / awaitable / function and schedules it on the event loop. + + Returns: Future + """ + if hasattr(coro, '__await__'): + assert not args, 'initiated coroutine does not accept args' + future = Future() + async def wrapper(): + try: + if inspect.iscoroutinefunction(coro): + future.success(await coro(*args)) + elif hasattr(coro, '__await__'): + future.success(await coro) + else: + future.success(coro(*args)) + except Exception as exc: + future.failure(exc) + self._net.call_soon(wrapper) + return future + + def run(self, coro, *args): + """Schedules coro on the event loop, blocks until complete, returns value or raises.""" + future = self.call_soon(coro, *args) + self.poll(future=future) + if future.exception is not None: + raise future.exception + return future.value diff --git a/test/net/test_manager.py b/test/net/test_manager.py index 50faffda0..b3b360da0 100644 --- a/test/net/test_manager.py +++ b/test/net/test_manager.py @@ -334,3 +334,48 @@ def test_close_nonexistent_node(self, manager): def test_close_no_connections(self, manager): # Should not raise manager.close() + + +class TestKafkaConnectionManagerRun: + def test_run_function(self, manager): + def test_coro(): + return 42 + assert manager.run(test_coro) == 42 + + def test_run_async_coro_function(self, manager): + async def test_coro(): + return 100 + assert manager.run(test_coro) == 100 + + def test_run_async_coro_with_args(self, manager): + async def test_coro(foo): + return foo + assert manager.run(test_coro, 123) == 123 + + def test_run_async_coro(self, manager): + async def test_coro(): + return 49 + assert manager.run(test_coro()) == 49 + + def test_run_async_chain(self, manager): + async def test_coro_foo(): + return 'foo!' + async def test_coro_bar(): + return await test_coro_foo() + assert manager.run(test_coro_bar()) == 'foo!' + + def test_run_raises(self, manager): + async def bad_coro(): + raise ValueError('bad_coro') + with pytest.raises(ValueError, match='bad_coro'): + manager.run(bad_coro) + + def test_call_soon_does_not_raise(self, manager): + async def bad_coro(): + raise ValueError('bad_coro') + future = manager.call_soon(bad_coro) + assert not future.is_done + manager.poll(future=future) + assert future.failed() + assert isinstance(future.exception, ValueError) + assert future.exception.args[0] == 'bad_coro'