diff --git a/kafka/net/selector.py b/kafka/net/selector.py index bbb1e2853..b89dcd31a 100644 --- a/kafka/net/selector.py +++ b/kafka/net/selector.py @@ -122,6 +122,7 @@ def __init__(self, **configs): if key in configs: self.config[key] = configs[key] + self._running = False self._closed = False self._stop = False self._selector = self.config['selector']() @@ -264,6 +265,11 @@ def remove_writer(self, fileobj): self.unregister_event(fileobj, selectors.EVENT_WRITE) def poll(self, timeout_ms=None, future=None): + if self._current: + raise RuntimError('Recursive access to net.poll!') + elif self._running: + raise RuntimeError('Concurrent access to net.poll!') + self._running = True start_at = time.monotonic() inner_timeout = timeout_ms / 1000 if timeout_ms is not None else None if future is not None and future.is_done: @@ -276,6 +282,7 @@ def poll(self, timeout_ms=None, future=None): inner_timeout = (timeout_ms / 1000) - (time.monotonic() - start_at) if inner_timeout <= 0: break + self._running = False def _poll_once(self, timeout=None): if self._ready: diff --git a/kafka/net/wakeup_notifier.py b/kafka/net/wakeup_notifier.py index d4757adb1..8822a5dc1 100644 --- a/kafka/net/wakeup_notifier.py +++ b/kafka/net/wakeup_notifier.py @@ -23,6 +23,8 @@ def _wakeup(self): self._fut.success(None) async def __call__(self, timeout_secs=None): + if self._fut is not None: + raise RuntimeError('Concurrent access to WakeupNotifier!') self._fut = Future() if timeout_secs is not None: try: diff --git a/test/integration/test_sasl_integration.py b/test/integration/test_sasl_integration.py index f75f6eccf..3cba1f5d1 100644 --- a/test/integration/test_sasl_integration.py +++ b/test/integration/test_sasl_integration.py @@ -74,7 +74,7 @@ def test_client(request, sasl_kafka): create_topics(sasl_kafka, [topic_name], num_partitions=1) client = KafkaNetClient(**client_params(sasl_kafka, 'client')) - client._manager.run(client._manager.bootstrap) + client._manager.run(client._manager.bootstrap_async) request = MetadataRequest(topics=None, version=1) timeout_at = time.time() + 1 future = client.send(None, request)