Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions kafka/admin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,11 @@ def __init__(self, **configs):
)
# Goal: migrate all self._client calls -> self._manager (skipping compat layer)
self._manager = self._client._manager
self._net = self._manager._net

# Run all IO on a dedicated background thread; public admin methods
# block on cross-thread Events via self._manager.run(...).
self._manager.start()
self._net.start()

# Bootstrap on __init__
self._manager.bootstrap(self.config['bootstrap_timeout_ms'])
Expand All @@ -237,9 +238,10 @@ def close(self):
log.info("KafkaAdminClient already closed.")
return

self._closed = True
self._metrics.close()
self._client.close()
self._closed = True
self._net.stop()
log.debug("KafkaAdminClient is now closed.")

def _validate_timeout(self, timeout_ms):
Expand All @@ -261,7 +263,7 @@ async def _refresh_controller_id(self, timeout_ms=30000):
controller_id = response.controller_id
if controller_id == -1:
log.warning("Controller ID not available, got -1")
await self._manager._net.sleep(1)
await self._net.sleep(1)
continue
return controller_id
else:
Expand Down
87 changes: 2 additions & 85 deletions kafka/net/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import random
import socket
import ssl
import threading
import time

from .inet import create_connection
Expand Down Expand Up @@ -78,9 +77,6 @@ def __init__(self, net, **configs):
self.broker_version_data = None
self._bootstrap_future = None
self._bootstrap_wakeup = WakeupNotifier(self._net)
self._io_thread = None
self._pending_waiters = {} # event -> state dict, for pending run() waiters
self._pending_waiters_lock = threading.Lock()
if self.config['metrics']:
self._sensors = KafkaManagerMetrics(
self.config['metrics'], self.config['metric_group_prefix'], self._conns)
Expand Down Expand Up @@ -348,34 +344,6 @@ def close(self, node_id=None, timeout_ms=None):
for conn in list(self._conns.values()):
conn.close()
self.cluster.close()
self.stop(timeout_ms)

def start(self):
    """Spawn a daemon IO thread that owns the event loop. Idempotent.

    The thread runs ``self._net.run_forever`` and is named
    ``kafka-io-<client_id>`` using ``self.config['client_id']``. If a
    thread is already recorded in ``self._io_thread`` this is a no-op.

    Raises:
        RuntimeError: propagated from ``threading.Thread.start()`` when
            the thread cannot be started. ``self._io_thread`` is reset
            to ``None`` in that case.
    """
    if self._io_thread is not None:
        return
    t = threading.Thread(target=self._net.run_forever,
                         name='kafka-io-%s' % self.config['client_id'],
                         daemon=True)
    self._io_thread = t
    try:
        t.start()
    except BaseException:
        # Do not keep a never-started thread recorded: run() would take the
        # cross-thread path and wait on an Event nobody will ever set, and
        # stop() would try to join() a thread that was never started.
        self._io_thread = None
        raise

def stop(self, timeout_ms=None):
    """Stop the IO thread, join it, and fail any run() callers still waiting.

    When no IO thread is recorded, only ``self._net.drain()`` is called.
    Otherwise the thread reference is cleared, ``self._net.stop()`` is
    called, and the thread is joined for at most ``timeout_ms``
    milliseconds (``None`` waits indefinitely). Every waiter still
    registered afterwards gets a ``KafkaConnectionError`` and has its
    event set so the blocked caller wakes up.
    """
    io_thread = self._io_thread
    if io_thread is None:
        self._net.drain()
        return

    self._io_thread = None
    self._net.stop()

    # Thread.join() takes seconds; the public argument is milliseconds.
    if timeout_ms is None:
        join_timeout = None
    else:
        join_timeout = timeout_ms / 1000
    io_thread.join(join_timeout)

    # Snapshot and clear under the lock, then notify the waiters.
    with self._pending_waiters_lock:
        orphaned = list(self._pending_waiters.items())
        self._pending_waiters.clear()
    for done_event, outcome in orphaned:
        outcome['exception'] = Errors.KafkaConnectionError('Manager stopped')
        done_event.set()

async def wait_for(self, future, timeout_ms):
"""Await `future` with a timeout in ms. Raises KafkaTimeoutError on timeout.
Expand Down Expand Up @@ -409,41 +377,14 @@ def _on_timeout():
except ValueError:
pass

async def _invoke(self, coro, args):
"""Invoke coro/awaitable/function and fully resolve the result.

If the result is itself a Future (e.g. send() returning an unresolved
Future), it is awaited so callers receive the resolved value.
"""
if inspect.iscoroutinefunction(coro):
result = await coro(*args)
elif hasattr(coro, '__await__'):
result = await coro
else:
result = coro(*args)
if inspect.iscoroutine(result) or hasattr(result, '__await__'):
result = await result
while isinstance(result, Future):
result = await result
return result

def call_soon(self, coro, *args):
"""Accepts a coroutine / awaitable / function and schedules it on the event loop.

Thread-safe.

Returns: Future
"""
if hasattr(coro, '__await__'):
assert not args, 'initiated coroutine does not accept args'
future = Future()
async def wrapper():
try:
future.success(await self._invoke(coro, args))
except BaseException as exc:
future.failure(exc)
self._net.call_soon_threadsafe(wrapper)
return future
return self._net.call_soon_with_future(coro, *args)

def run(self, coro, *args):
"""Schedules coro on the event loop, blocks until complete, returns value or raises.
Expand All @@ -455,28 +396,4 @@ def run(self, coro, *args):
If no IO thread is running, falls back to driving the loop on the
caller thread (legacy behavior).
"""
if self._io_thread is None:
future = self.call_soon(coro, *args)
self._net.poll(future=future)
if future.exception is not None:
raise future.exception
return future.value

event = threading.Event()
state = {'value': None, 'exception': None}
async def waiter():
try:
state['value'] = await self._invoke(coro, args)
except BaseException as exc:
state['exception'] = exc
finally:
with self._pending_waiters_lock:
self._pending_waiters.pop(event, None)
event.set()
with self._pending_waiters_lock:
self._pending_waiters[event] = state
self._net.call_soon_threadsafe(waiter)
event.wait()
if state['exception'] is not None:
raise state['exception']
return state['value']
return self._net.run(coro, *args)
131 changes: 108 additions & 23 deletions kafka/net/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import threading
import time

import kafka.errors as Errors
from kafka.future import Future
from kafka.version import __version__


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -117,6 +119,7 @@ def exception(self):

class NetworkSelector:
DEFAULT_CONFIG = {
'client_id': 'kafka-python-' + __version__,
'selector': selectors.DefaultSelector,
# Warn (or, in debug mode, raise) when a single ready-task step takes
# longer than this many seconds. A coroutine that hits this threshold
Expand Down Expand Up @@ -163,38 +166,85 @@ def __init__(self, **configs):
self._wakeup_r.setblocking(False)
self._wakeup_w.setblocking(False)
self._selector.register(self._wakeup_r, selectors.EVENT_READ, (None, None))
self._io_thread = None
self._pending_waiters = {} # event -> state dict, for pending run() waiters
self._pending_waiters_lock = threading.Lock()

def __str__(self):
    """Return a one-line summary of the loop's queue sizes.

    ``ready`` and ``scheduled`` are the lengths of the corresponding
    task collections; ``waiting`` is the number of entries in the
    selector's registration map.
    """
    counts = (
        len(self._ready),
        len(self._scheduled),
        len(self._selector.get_map()),
    )
    return '<NetworkSelector ready=%d scheduled=%d waiting=%d>' % counts

def run(self):
    """Poll repeatedly until no scheduled and no ready work remains."""
    while True:
        if not (self._scheduled or self._ready):
            return
        self._poll_once()

def run_forever(self):
    """Drive the event loop until the stop flag is raised.

    Meant to be the target of a dedicated IO thread. The stop flag is
    cleared on entry, then the loop polls until ``self._stop`` becomes
    true. Other threads must hand work over via call_soon_threadsafe()
    so that the blocking select() returns promptly.
    """
    self._stop = False
    while True:
        if self._stop:
            break
        self._poll_once()

def stop(self):
self.drain()

def start(self):
    """Spawn a daemon IO thread that owns the event loop. Idempotent.

    The thread runs ``self.run_forever`` and is named
    ``kafka-io-<client_id>`` using ``self.config['client_id']``. If a
    thread is already recorded in ``self._io_thread`` this is a no-op.

    Raises:
        RuntimeError: propagated from ``threading.Thread.start()`` when
            the thread cannot be started. ``self._io_thread`` is reset
            to ``None`` in that case.
    """
    if self._io_thread is not None:
        return
    t = threading.Thread(target=self.run_forever,
                         name='kafka-io-%s' % self.config['client_id'],
                         daemon=True)
    self._io_thread = t
    try:
        t.start()
    except BaseException:
        # Do not keep a never-started thread recorded: callers that check
        # ``_io_thread`` would treat the loop as running on another thread
        # and wait for work that nothing is going to execute.
        self._io_thread = None
        raise

def stop(self, timeout_ms=None):
"""Signal run_forever() to exit. Safe to call from any thread."""
if self._stop or self._io_thread is None:
return
self._stop = True
self.wakeup()

def run_until_done(self, task_or_future):
if not isinstance(task_or_future, (Future, Task)):
task_or_future = Task(task_or_future)
if isinstance(task_or_future, Task):
self.call_soon(task_or_future)
while not task_or_future.is_done:
self._poll_once()
return task_or_future

def drain(self):
while self._ready:
self._io_thread.join(timeout_ms / 1000 if timeout_ms is not None else None)
self._io_thread = None
with self._pending_waiters_lock:
waiters = list(self._pending_waiters.items())
self._pending_waiters.clear()
for event, state in waiters:
state['exception'] = Errors.KafkaConnectionError('Manager stopped')
event.set()

def run(self, coro, *args):
    """Run `coro` on the event loop and block until it has finished.

    Returns the resolved value, or re-raises whatever exception the
    coroutine produced.

    With an IO thread recorded (see start()), the work is handed to the
    loop and the calling thread blocks on a ``threading.Event`` until the
    result is stored. Each call uses its own event and result slot, so
    several caller threads may use this concurrently.

    Without an IO thread, the loop is driven on the calling thread until
    the future completes (legacy behavior).
    """
    if self._io_thread is None:
        # Legacy path: no background thread, so poll the loop right here.
        pending = self.call_soon_with_future(coro, *args)
        self.poll(future=pending)
        failure = pending.exception
        if failure is not None:
            raise failure
        return pending.value

    done = threading.Event()
    outcome = {'value': None, 'exception': None}

    async def waiter():
        try:
            outcome['value'] = await self._invoke(coro, *args)
        except BaseException as exc:
            outcome['exception'] = exc
        finally:
            # Deregister before signalling so the entry is gone by the
            # time the blocked caller resumes.
            with self._pending_waiters_lock:
                self._pending_waiters.pop(done, None)
            done.set()

    # Register first: stop() fails every waiter still present in this map.
    with self._pending_waiters_lock:
        self._pending_waiters[done] = outcome
    self.call_soon_threadsafe(waiter)
    done.wait()

    failure = outcome['exception']
    if failure is not None:
        raise failure
    return outcome['value']

def drain(self, scheduled=False):
    """Poll until the ready queue is empty.

    When `scheduled` is true, keep polling until the scheduled queue is
    empty as well.
    """
    while True:
        has_work = bool(self._ready) or (scheduled and bool(self._scheduled))
        if not has_work:
            break
        self._poll_once()

def call_at(self, when, task):
Expand All @@ -218,6 +268,41 @@ def call_soon(self, task):
self._pending_tasks.add(task)
return task

def call_soon_threadsafe(self, callback):
    """Schedule `callback` via call_soon() and then wake the loop.

    Returns whatever call_soon() returned for the callback.
    """
    scheduled = self.call_soon(callback)
    # Wake the loop after queueing so the new work is noticed promptly.
    self.wakeup()
    return scheduled

def call_soon_with_future(self, coro, *args):
    """Schedule `coro` on the loop and return a Future for its outcome.

    `coro` is handed to ``self._invoke`` together with `args`. An object
    that is already awaitable (has ``__await__``) must be passed without
    extra arguments. The returned Future is completed with the awaited
    value, or failed with the exception that was raised.
    """
    already_awaitable = hasattr(coro, '__await__')
    if already_awaitable:
        assert not args, 'initiated coroutine does not accept args'

    result_future = Future()

    async def wrapper():
        try:
            resolved = await self._invoke(coro, *args)
            result_future.success(resolved)
        except BaseException as exc:
            result_future.failure(exc)

    self.call_soon_threadsafe(wrapper)
    return result_future

async def _invoke(self, coro, *args):
"""Invoke coro/awaitable/function and fully resolve the result.

If the result is itself a Future (e.g. send() returning an unresolved
Future), it is awaited so callers receive the resolved value.
"""
if inspect.iscoroutinefunction(coro):
result = await coro(*args)
elif hasattr(coro, '__await__'):
result = await coro
else:
result = coro(*args)
if inspect.iscoroutine(result) or hasattr(result, '__await__'):
result = await result
while isinstance(result, Future):
result = await result
return result

def unschedule(self, task):
if task.scheduled_at is not None:
self._scheduled.remove((task.scheduled_at, task))
Expand Down Expand Up @@ -402,11 +487,6 @@ def wakeup(self):
except (BlockingIOError, OSError):
pass

def call_soon_threadsafe(self, callback):
    """Schedule `callback` via call_soon() and then wake the loop.

    Returns whatever call_soon() returned for the callback.
    """
    scheduled = self.call_soon(callback)
    # Wake the loop after queueing so the new work is noticed promptly.
    self.wakeup()
    return scheduled

def _rebuild_wakeup_socketpair(self):
for s in (self._wakeup_r, self._wakeup_w):
try:
Expand All @@ -423,7 +503,12 @@ def _rebuild_wakeup_socketpair(self):
self._selector.register(self._wakeup_r, selectors.EVENT_READ, (None, None))

def close(self):
if self._closed:
return
self._closed = True
if self._io_thread is not None:
self.stop()
self.drain()
for s in (self._wakeup_r, self._wakeup_w):
try:
self._selector.unregister(s)
Expand Down
10 changes: 5 additions & 5 deletions test/net/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def mock_send_request(request, **kwargs):
return f

conn._send_request = mock_send_request
net.run_until_done(conn._check_version())
net.run(conn._check_version())
return requests_sent

def test_first_request_is_max_version(self, net):
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_sasl_authenticate_handshake_error(self, net):
f.success(handshake_response)
conn._send_request = MagicMock(return_value=f)

net.run_until_done(conn._sasl_authenticate())
net.run(conn._sasl_authenticate())
transport.abort.assert_called_once()

def test_sasl_authenticate_mechanism_not_supported(self, net):
Expand All @@ -392,7 +392,7 @@ def test_sasl_authenticate_mechanism_not_supported(self, net):
f.success(handshake_response)
conn._send_request = MagicMock(return_value=f)

net.run_until_done(conn._sasl_authenticate())
net.run(conn._sasl_authenticate())
transport.abort.assert_called_once()

def test_sasl_authenticate_success(self, net):
Expand Down Expand Up @@ -428,7 +428,7 @@ def mock_send_request(request):
return f
conn._send_request = mock_send_request

net.run_until_done(conn._sasl_authenticate())
net.run(conn._sasl_authenticate())
# Should not have closed -- auth succeeded
assert conn.initializing # still initializing, _init_complete not called by _sasl_authenticate

Expand Down Expand Up @@ -465,5 +465,5 @@ def mock_send_request(request):
return f
conn._send_request = mock_send_request

net.run_until_done(conn._sasl_authenticate())
net.run(conn._sasl_authenticate())
transport.abort.assert_called_once()
Loading
Loading