diff --git a/kafka/admin/client.py b/kafka/admin/client.py index 61e8d5d7f..bd078741d 100644 --- a/kafka/admin/client.py +++ b/kafka/admin/client.py @@ -219,10 +219,11 @@ def __init__(self, **configs): ) # Goal: migrate all self._client calls -> self._manager (skipping compat layer) self._manager = self._client._manager + self._net = self._manager._net # Run all IO on a dedicated background thread; public admin methods # block on cross-thread Events via self._manager.run(...). - self._manager.start() + self._net.start() # Bootstrap on __init__ self._manager.bootstrap(self.config['bootstrap_timeout_ms']) @@ -237,9 +238,10 @@ def close(self): log.info("KafkaAdminClient already closed.") return + self._closed = True self._metrics.close() self._client.close() - self._closed = True + self._net.stop() log.debug("KafkaAdminClient is now closed.") def _validate_timeout(self, timeout_ms): @@ -261,7 +263,7 @@ async def _refresh_controller_id(self, timeout_ms=30000): controller_id = response.controller_id if controller_id == -1: log.warning("Controller ID not available, got -1") - await self._manager._net.sleep(1) + await self._net.sleep(1) continue return controller_id else: diff --git a/kafka/net/manager.py b/kafka/net/manager.py index f920f56ae..113965015 100644 --- a/kafka/net/manager.py +++ b/kafka/net/manager.py @@ -4,7 +4,6 @@ import random import socket import ssl -import threading import time from .inet import create_connection @@ -78,9 +77,6 @@ def __init__(self, net, **configs): self.broker_version_data = None self._bootstrap_future = None self._bootstrap_wakeup = WakeupNotifier(self._net) - self._io_thread = None - self._pending_waiters = {} # event -> state dict, for pending run() waiters - self._pending_waiters_lock = threading.Lock() if self.config['metrics']: self._sensors = KafkaManagerMetrics( self.config['metrics'], self.config['metric_group_prefix'], self._conns) @@ -348,34 +344,6 @@ def close(self, node_id=None, timeout_ms=None): for conn in list(self._conns.values()): conn.close() self.cluster.close() - self.stop(timeout_ms) - - def start(self): - """Spawn a daemon IO thread that owns the event loop. Idempotent.""" - if self._io_thread is not None: - return - t = threading.Thread(target=self._net.run_forever, - name='kafka-io-%s' % self.config['client_id'], - daemon=True) - self._io_thread = t - t.start() - - def stop(self, timeout_ms=None): - """Signal the IO thread to exit and join it. Fails any pending run() - waiters with KafkaConnectionError. Idempotent.""" - t = self._io_thread - if t is None: - self._net.drain() - return - self._io_thread = None - self._net.stop() - t.join(timeout_ms / 1000 if timeout_ms is not None else None) - with self._pending_waiters_lock: - waiters = list(self._pending_waiters.items()) - self._pending_waiters.clear() - for event, state in waiters: - state['exception'] = Errors.KafkaConnectionError('Manager stopped') - event.set() async def wait_for(self, future, timeout_ms): """Await `future` with a timeout in ms. Raises KafkaTimeoutError on timeout. @@ -409,24 +377,6 @@ def _on_timeout(): except ValueError: pass - async def _invoke(self, coro, args): - """Invoke coro/awaitable/function and fully resolve the result. - - If the result is itself a Future (e.g. send() returning an unresolved - Future), it is awaited so callers receive the resolved value. - """ - if inspect.iscoroutinefunction(coro): - result = await coro(*args) - elif hasattr(coro, '__await__'): - result = await coro - else: - result = coro(*args) - if inspect.iscoroutine(result) or hasattr(result, '__await__'): - result = await result - while isinstance(result, Future): - result = await result - return result - def call_soon(self, coro, *args): """Accepts a coroutine / awaitable / function and schedules it on the event loop. @@ -434,16 +384,7 @@ def call_soon(self, coro, *args): Returns: Future """ - if hasattr(coro, '__await__'): - assert not args, 'initiated coroutine does not accept args' - future = Future() - async def wrapper(): - try: - future.success(await self._invoke(coro, args)) - except BaseException as exc: - future.failure(exc) - self._net.call_soon_threadsafe(wrapper) - return future + return self._net.call_soon_with_future(coro, *args) def run(self, coro, *args): """Schedules coro on the event loop, blocks until complete, returns value or raises. @@ -455,28 +396,4 @@ def run(self, coro, *args): If no IO thread is running, falls back to driving the loop on the caller thread (legacy behavior). """ - if self._io_thread is None: - future = self.call_soon(coro, *args) - self._net.poll(future=future) - if future.exception is not None: - raise future.exception - return future.value - - event = threading.Event() - state = {'value': None, 'exception': None} - async def waiter(): - try: - state['value'] = await self._invoke(coro, args) - except BaseException as exc: - state['exception'] = exc - finally: - with self._pending_waiters_lock: - self._pending_waiters.pop(event, None) - event.set() - with self._pending_waiters_lock: - self._pending_waiters[event] = state - self._net.call_soon_threadsafe(waiter) - event.wait() - if state['exception'] is not None: - raise state['exception'] - return state['value'] + return self._net.run(coro, *args) diff --git a/kafka/net/selector.py b/kafka/net/selector.py index 6d77acfaa..c3d395055 100644 --- a/kafka/net/selector.py +++ b/kafka/net/selector.py @@ -8,7 +8,9 @@ import threading import time +import kafka.errors as Errors from kafka.future import Future +from kafka.version import __version__ log = logging.getLogger(__name__) @@ -117,6 +119,7 @@ def exception(self): class NetworkSelector: DEFAULT_CONFIG = { + 'client_id': 'kafka-python-' + __version__, 'selector': selectors.DefaultSelector, # Warn (or, in debug mode, raise) when a single ready-task step takes # longer than this many seconds. A coroutine that hits this threshold @@ -163,14 +166,13 @@ def __init__(self, **configs): self._wakeup_r.setblocking(False) self._wakeup_w.setblocking(False) self._selector.register(self._wakeup_r, selectors.EVENT_READ, (None, None)) + self._io_thread = None + self._pending_waiters = {} # event -> state dict, for pending run() waiters + self._pending_waiters_lock = threading.Lock() def __str__(self): return '' % (len(self._ready), len(self._scheduled), len(self._selector.get_map())) - def run(self): - while self._scheduled or self._ready: - self._poll_once() - def run_forever(self): """Run the event loop until stop() is called. Intended to be driven by a dedicated IO thread. Wake-ups from other threads must go through @@ -178,23 +180,71 @@ def run_forever(self): self._stop = False while not self._stop: self._poll_once() - - def stop(self): + self.drain() + + def start(self): + """Spawn a daemon IO thread that owns the event loop. Idempotent.""" + if self._io_thread is not None: + return + t = threading.Thread(target=self.run_forever, + name='kafka-io-%s' % self.config['client_id'], + daemon=True) + self._io_thread = t + t.start() + + def stop(self, timeout_ms=None): """Signal run_forever() to exit. Safe to call from any thread.""" + if self._stop or self._io_thread is None: + return self._stop = True self.wakeup() - - def run_until_done(self, task_or_future): - if not isinstance(task_or_future, (Future, Task)): - task_or_future = Task(task_or_future) - if isinstance(task_or_future, Task): - self.call_soon(task_or_future) - while not task_or_future.is_done: - self._poll_once() - return task_or_future - - def drain(self): - while self._ready: + self._io_thread.join(timeout_ms / 1000 if timeout_ms is not None else None) + self._io_thread = None + with self._pending_waiters_lock: + waiters = list(self._pending_waiters.items()) + self._pending_waiters.clear() + for event, state in waiters: + state['exception'] = Errors.KafkaConnectionError('Manager stopped') + event.set() + + def run(self, coro, *args): + """Schedules coro on the event loop, blocks until complete, returns value or raises. + + If an IO thread is running (via start()), the caller thread blocks on + a cross-thread Event while the coroutine runs on the IO thread. Safe + to call concurrently from multiple caller threads. + + If no IO thread is running, falls back to driving the loop on the + caller thread (legacy behavior). + """ + if self._io_thread is None: + future = self.call_soon_with_future(coro, *args) + self.poll(future=future) + if future.exception is not None: + raise future.exception + return future.value + + event = threading.Event() + state = {'value': None, 'exception': None} + async def waiter(): + try: + state['value'] = await self._invoke(coro, *args) + except BaseException as exc: + state['exception'] = exc + finally: + with self._pending_waiters_lock: + self._pending_waiters.pop(event, None) + event.set() + with self._pending_waiters_lock: + self._pending_waiters[event] = state + self.call_soon_threadsafe(waiter) + event.wait() + if state['exception'] is not None: + raise state['exception'] + return state['value'] + + def drain(self, scheduled=False): + while self._ready or (scheduled and self._scheduled): self._poll_once() def call_at(self, when, task): @@ -218,6 +268,41 @@ def call_soon(self, task): self._pending_tasks.add(task) return task + def call_soon_threadsafe(self, callback): + task = self.call_soon(callback) + self.wakeup() + return task + + def call_soon_with_future(self, coro, *args): + if hasattr(coro, '__await__'): + assert not args, 'initiated coroutine does not accept args' + future = Future() + async def wrapper(): + try: + future.success(await self._invoke(coro, *args)) + except BaseException as exc: + future.failure(exc) + self.call_soon_threadsafe(wrapper) + return future + + async def _invoke(self, coro, *args): + """Invoke coro/awaitable/function and fully resolve the result. + + If the result is itself a Future (e.g. send() returning an unresolved + Future), it is awaited so callers receive the resolved value. + """ + if inspect.iscoroutinefunction(coro): + result = await coro(*args) + elif hasattr(coro, '__await__'): + result = await coro + else: + result = coro(*args) + if inspect.iscoroutine(result) or hasattr(result, '__await__'): + result = await result + while isinstance(result, Future): + result = await result + return result + def unschedule(self, task): if task.scheduled_at is not None: self._scheduled.remove((task.scheduled_at, task)) @@ -402,11 +487,6 @@ def wakeup(self): except (BlockingIOError, OSError): pass - def call_soon_threadsafe(self, callback): - task = self.call_soon(callback) - self.wakeup() - return task - def _rebuild_wakeup_socketpair(self): for s in (self._wakeup_r, self._wakeup_w): try: @@ -423,7 +503,12 @@ def _rebuild_wakeup_socketpair(self): self._selector.register(self._wakeup_r, selectors.EVENT_READ, (None, None)) def close(self): + if self._closed: + return self._closed = True + if self._io_thread is not None: + self.stop() + self.drain() for s in (self._wakeup_r, self._wakeup_w): try: self._selector.unregister(s) diff --git a/test/net/test_connection.py b/test/net/test_connection.py index 99940fde2..3086e9215 100644 --- a/test/net/test_connection.py +++ b/test/net/test_connection.py @@ -89,7 +89,7 @@ def mock_send_request(request, **kwargs): return f conn._send_request = mock_send_request - net.run_until_done(conn._check_version()) + net.run(conn._check_version()) return requests_sent def test_first_request_is_max_version(self, net): @@ -367,7 +367,7 @@ def test_sasl_authenticate_handshake_error(self, net): f.success(handshake_response) conn._send_request = MagicMock(return_value=f) - net.run_until_done(conn._sasl_authenticate()) + net.run(conn._sasl_authenticate()) transport.abort.assert_called_once() def test_sasl_authenticate_mechanism_not_supported(self, net): @@ -392,7 +392,7 @@ def test_sasl_authenticate_mechanism_not_supported(self, net): f.success(handshake_response) conn._send_request = MagicMock(return_value=f) - net.run_until_done(conn._sasl_authenticate()) + net.run(conn._sasl_authenticate()) transport.abort.assert_called_once() def test_sasl_authenticate_success(self, net): @@ -428,7 +428,7 @@ def mock_send_request(request): return f conn._send_request = mock_send_request - net.run_until_done(conn._sasl_authenticate()) + net.run(conn._sasl_authenticate()) # Should not have closed -- auth succeeded assert conn.initializing # still initializing, _init_complete not called by _sasl_authenticate @@ -465,5 +465,5 @@ def mock_send_request(request): return f conn._send_request = mock_send_request - net.run_until_done(conn._sasl_authenticate()) + net.run(conn._sasl_authenticate()) transport.abort.assert_called_once() diff --git a/test/net/test_inet.py b/test/net/test_inet.py index f1e4cb423..ac611d237 100644 --- a/test/net/test_inet.py +++ b/test/net/test_inet.py @@ -32,49 +32,45 @@ def test_immediate_connect(self): net = NetworkSelector() sock = MagicMock() sock.connect_ex.return_value = 0 - task = net.run_until_done(connect_sock(net, sock, ('127.0.0.1', 9092))) - assert task.result is sock + result = net.run(connect_sock(net, sock, ('127.0.0.1', 9092))) + assert result is sock def test_eisconn(self): net = NetworkSelector() sock = MagicMock() sock.connect_ex.return_value = errno.EISCONN - task = net.run_until_done(connect_sock(net, sock, ('127.0.0.1', 9092))) - assert task.result is sock + result = net.run(connect_sock(net, sock, ('127.0.0.1', 9092))) + assert result is sock def test_connection_refused(self): net = NetworkSelector() sock = MagicMock() sock.connect_ex.return_value = errno.ECONNREFUSED - task = net.run_until_done(connect_sock(net, sock, ('127.0.0.1', 9092))) - assert task.is_done - assert isinstance(task.exception, Errors.KafkaConnectionError) + with pytest.raises(Errors.KafkaConnectionError): + net.run(connect_sock(net, sock, ('127.0.0.1', 9092))) def test_socket_error_uses_errno(self): net = NetworkSelector() sock = MagicMock() sock.connect_ex.side_effect = socket.error(errno.ECONNREFUSED, 'refused') - task = net.run_until_done(connect_sock(net, sock, ('127.0.0.1', 9092))) - assert isinstance(task.exception, Errors.KafkaConnectionError) + with pytest.raises(Errors.KafkaConnectionError): + net.run(connect_sock(net, sock, ('127.0.0.1', 9092))) class TestCreateConnection: def test_dns_failure(self): net = NetworkSelector() with patch('kafka.net.inet.dns_lookup', return_value=[]): - task = net.run_until_done( - create_connection(net, 'badhost', 9092)) - assert isinstance(task.exception, Errors.KafkaConnectionError) - assert 'DNS' in str(task.exception) + with pytest.raises(Errors.KafkaConnectionError, match='DNS'): + net.run(create_connection(net, 'badhost', 9092)) def test_socket_init_failure(self): net = NetworkSelector() fake_addr = [(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('127.0.0.1', 9092))] with patch('kafka.net.inet.dns_lookup', return_value=fake_addr), \ patch('kafka.net.inet.socket.socket', side_effect=OSError('no socket')): - task = net.run_until_done( - create_connection(net, 'host', 9092)) - assert isinstance(task.exception, Errors.KafkaConnectionError) + with pytest.raises(Errors.KafkaConnectionError): + net.run(create_connection(net, 'host', 9092)) def test_successful_connection(self): net = NetworkSelector() @@ -83,9 +79,9 @@ def test_successful_connection(self): mock_sock.connect_ex.return_value = 0 with patch('kafka.net.inet.dns_lookup', return_value=fake_addr), \ patch('kafka.net.inet.socket.socket', return_value=mock_sock): - task = net.run_until_done( + result = net.run( create_connection(net, 'host', 9092)) - assert task.result is mock_sock + assert result is mock_sock mock_sock.setblocking.assert_called_with(False) def test_tries_multiple_addresses(self): @@ -99,9 +95,9 @@ def test_tries_multiple_addresses(self): sockets = iter([mock_sock1, mock_sock2]) with patch('kafka.net.inet.dns_lookup', return_value=[addr1, addr2]), \ patch('kafka.net.inet.socket.socket', side_effect=lambda *a: next(sockets)): - task = net.run_until_done( + result = net.run( create_connection(net, 'host', 9092)) - assert task.result is mock_sock2 + assert result is mock_sock2 def test_so_error_after_wait(self): net = NetworkSelector() @@ -128,8 +124,8 @@ def test_proxy_connect_ex_called(self): sock = MagicMock() proxy = MagicMock() proxy.connect_ex.return_value = 0 - task = net.run_until_done(connect_sock(net, sock, ('127.0.0.1', 9092), proxy=proxy)) - assert task.result is sock + result = net.run(connect_sock(net, sock, ('127.0.0.1', 9092), proxy=proxy)) + assert result is sock proxy.connect_ex.assert_called_once_with(('127.0.0.1', 9092)) sock.connect_ex.assert_not_called() @@ -143,8 +139,8 @@ def test_proxy_blocking_io_retries(self): # Use real socket fd so wait_write works mock_sock = MagicMock() mock_sock.fileno.return_value = wsock.fileno() - task = net.run_until_done(connect_sock(net, mock_sock, ('127.0.0.1', 9092), proxy=proxy)) - assert task.result is mock_sock + result = net.run(connect_sock(net, mock_sock, ('127.0.0.1', 9092), proxy=proxy)) + assert result is mock_sock assert proxy.connect_ex.call_count == 2 rsock.close() wsock.close() @@ -154,8 +150,8 @@ def test_proxy_connection_refused(self): sock = MagicMock() proxy = MagicMock() proxy.connect_ex.return_value = errno.ECONNREFUSED - task = net.run_until_done(connect_sock(net, sock, ('127.0.0.1', 9092), proxy=proxy)) - assert isinstance(task.exception, Errors.KafkaConnectionError) + with pytest.raises(Errors.KafkaConnectionError): + net.run(connect_sock(net, sock, ('127.0.0.1', 9092), proxy=proxy)) class TestCreateConnectionWithProxy: @@ -167,11 +163,11 @@ def test_proxy_creates_socket_via_wrapper(self): mock_proxy.socket.return_value = mock_sock mock_proxy.connect_ex.return_value = 0 with patch('kafka.net.inet.Socks5Wrapper', return_value=mock_proxy) as mock_cls: - task = net.run_until_done( + result = net.run( create_connection(net, 'broker', 9092, socks5_proxy='socks5://proxy:1080')) mock_cls.assert_called_once_with('socks5://proxy:1080', socket.AF_UNSPEC) mock_proxy.socket.assert_called_once() - assert task.result is mock_sock + assert result is mock_sock def test_proxy_remote_dns_skips_local_lookup(self): net = NetworkSelector() @@ -184,7 +180,7 @@ def test_proxy_remote_dns_skips_local_lookup(self): patch('kafka.net.inet.dns_lookup') as mock_dns: mock_cls.return_value = mock_proxy mock_cls.use_remote_lookup.return_value = True - task = net.run_until_done( + result = net.run( create_connection(net, 'broker', 9092, socks5_proxy='socks5h://proxy:1080')) mock_dns.assert_not_called() @@ -196,7 +192,7 @@ def test_no_proxy_uses_direct_socket(self): with patch('kafka.net.inet.dns_lookup', return_value=fake_addr), \ patch('kafka.net.inet.socket.socket', return_value=mock_sock), \ patch('kafka.net.inet.Socks5Wrapper') as mock_cls: - task = net.run_until_done( + result = net.run( create_connection(net, 'host', 9092)) mock_cls.assert_not_called() - assert task.result is mock_sock + assert result is mock_sock diff --git a/test/net/test_selector.py b/test/net/test_selector.py index 71b5f56c1..666e530dc 100644 --- a/test/net/test_selector.py +++ b/test/net/test_selector.py @@ -161,16 +161,16 @@ def task(): net.call_soon(task) assert len(net._ready) == 1 - def test_run_simple(self): + def test_drain_simple(self): net = NetworkSelector() results = [] async def task(): results.append('ran') net.call_soon(task) - net.run() + net.drain() assert results == ['ran'] - def test_run_multiple_tasks(self): + def test_drain_multiple_tasks(self): net = NetworkSelector() results = [] async def task(n): @@ -178,7 +178,7 @@ async def task(n): net.call_soon(task(1)) net.call_soon(task(2)) net.call_soon(task(3)) - net.run() + net.drain() assert results == [1, 2, 3] def test_call_later(self): @@ -188,7 +188,7 @@ async def task(): results.append('delayed') net.call_later(0.01, task) assert len(net._scheduled) == 1 - net.run() + net.drain(scheduled=True) assert results == ['delayed'] def test_call_at(self): @@ -198,25 +198,15 @@ async def task(): results.append('at') when = time.monotonic() + 0.01 net.call_at(when, task) - net.run() + net.drain(scheduled=True) assert results == ['at'] - def test_run_until_done_with_task(self): + def test_run_with_task(self): net = NetworkSelector() async def task(): return 42 - result = net.run_until_done(task) - assert result.is_done - assert result.result == 42 - - def test_run_until_done_with_future(self): - net = NetworkSelector() - f = Future() - async def resolve(): - f.success('hello') - net.call_soon(resolve) - net.run_until_done(f) - assert f.value == 'hello' + result = net.run(task) + assert result == 42 def test_sleep(self): net = NetworkSelector() @@ -243,7 +233,7 @@ def test_sleep_in_coroutine(self): async def task(): await net.sleep(0.01) results.append('after_sleep') - net.run_until_done(task) + net.run(task) assert results == ['after_sleep'] def test_poll_with_future(self): @@ -371,7 +361,7 @@ async def waiter(): val = await f results.append(val) - net.run_until_done(waiter) + net.run(waiter) assert results == [99] def test_await_future_already_failed(self): @@ -386,7 +376,7 @@ async def waiter(): except ValueError as e: errors.append(str(e)) - net.run_until_done(waiter) + net.run(waiter) assert errors == ['oops'] def test_wakeup(self): @@ -448,7 +438,7 @@ async def task(): val = await f results.append(val) - net.run_until_done(task) + net.run(task) assert results == [42] def test_await_future_pending(self):