Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions kafka/net/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,15 @@ def __init__(self, **configs):
self._selector = self.config['selector']()
self._scheduled = [] # managed by heapq
self._ready = collections.deque()
# Strong refs to every Task that hasn't completed yet. Without this,
# a Task suspended on an externally-unreachable awaitable (e.g. a
# Future created and awaited inside the Task's own coroutine) forms
# an orphan cycle and is subject to gc collection. Keeping every
# pending Task rooted on the selector itself prevents the cycle from
# ever being garbage-eligible. Tasks are removed when they raise
# StopIteration (normal completion) or BaseException (raised) inside
# _poll_once. This mirrors asyncio's loop._tasks weakset.
self._pending_tasks = set()
self._current = None
self._wakeup_r, self._wakeup_w = socket.socketpair()
self._wakeup_r.setblocking(False)
Expand Down Expand Up @@ -193,6 +202,7 @@ def call_at(self, when, task):
task = Task(task)
task.scheduled_at = when
heapq.heappush(self._scheduled, (when, task))
self._pending_tasks.add(task)
return task

def call_later(self, delay, task):
Expand All @@ -205,6 +215,7 @@ def call_soon(self, task):
if not isinstance(task, Task):
task = Task(task)
self._ready.append(task)
self._pending_tasks.add(task)
return task

def unschedule(self, task):
Expand Down Expand Up @@ -347,10 +358,14 @@ def _poll_once(self, timeout=None):
event = self._current()

except StopIteration:
pass
# Task ran to completion. Drop the strong ref so the Task
# (and its coroutine, frames, locals) is now collectable.
self._pending_tasks.discard(self._current)

except BaseException as e:
log.exception(e)
# Same as StopIteration -- task is done either way.
self._pending_tasks.discard(self._current)

else:
if isinstance(event, KernelEvent):
Expand Down Expand Up @@ -388,8 +403,9 @@ def wakeup(self):
pass

def call_soon_threadsafe(self, callback):
self.call_soon(callback)
task = self.call_soon(callback)
self.wakeup()
return task

def _rebuild_wakeup_socketpair(self):
for s in (self._wakeup_r, self._wakeup_w):
Expand Down
35 changes: 35 additions & 0 deletions test/net/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,38 @@ async def bad_coro():
assert future.failed()
assert isinstance(future.exception, ValueError)
assert future.exception.args[0] == 'bad_coro'

def test_run_survives_gc_during_poll(self, manager, monkeypatch):
"""Regression: an aggressive gc.collect() between _poll_once
iterations must not close orphan-cycle suspended coroutines and mask
the real result with GeneratorExit.

The wrapper Future returned by manager.call_soon pins its Task via
a no-op callback so the cycle (Future_yielded <-> _poll_once cb <->
Task <-> coroutine <-> Future_yielded) has an external reference
for as long as the wrapper Future is pending.
"""
import gc
from kafka.net.selector import NetworkSelector

# Force a GC cycle on every _poll_once entry to deterministically
# trigger the orphan-collection race that was masking timeouts in CI.
orig_poll_once = NetworkSelector._poll_once

def aggressive_poll_once(self, timeout=None):
gc.collect()
return orig_poll_once(self, timeout)
monkeypatch.setattr(NetworkSelector, '_poll_once', aggressive_poll_once)

async def hangs_then_times_out():
# Awaits a bare Future that nothing references externally --
# exactly the orphan-cycle shape that CPython's gc collects.
await Future()

# wait_for should fail with KafkaTimeoutError, not GeneratorExit.
async def waiter():
inner = manager.call_soon(hangs_then_times_out)
return await manager.wait_for(inner, timeout_ms=50)

with pytest.raises(Errors.KafkaTimeoutError):
manager.run(waiter)
Loading