diff --git a/kafka/net/connection.py b/kafka/net/connection.py index f7333002d..9ca0547e4 100644 --- a/kafka/net/connection.py +++ b/kafka/net/connection.py @@ -198,7 +198,7 @@ def data_received(self, data): self.unpause('max_in_flight') if self.in_flight_requests: next_timeout_at = self.in_flight_requests[0][3] - self.net.reschedule(self._timeout_task, next_timeout_at) + self.net.reschedule(next_timeout_at, self._timeout_task) else: self.net.unschedule(self._timeout_task) self._timeout_task = None diff --git a/kafka/net/selector.py b/kafka/net/selector.py index b89dcd31a..0198367cf 100644 --- a/kafka/net/selector.py +++ b/kafka/net/selector.py @@ -183,9 +183,11 @@ def call_soon(self, task): return task def unschedule(self, task): - self._scheduled.remove((task.scheduled_at, task)) + if task.scheduled_at is not None: + self._scheduled.remove((task.scheduled_at, task)) + task.scheduled_at = None - def reschedule(self, task, when): + def reschedule(self, when, task): self.unschedule(task) self.call_at(when, task) return task @@ -215,7 +217,9 @@ def _wait_read(self, fileobj): def _schedule_tasks(self): while self._scheduled and self._scheduled[0][0] <= time.monotonic(): - self._ready.append(heapq.heappop(self._scheduled)[1]) + _, task = heapq.heappop(self._scheduled) + task.scheduled_at = None + self._ready.append(task) def _next_scheduled_timeout(self, now): try: diff --git a/test/net/test_selector.py b/test/net/test_selector.py index 39aa9e6fe..727233613 100644 --- a/test/net/test_selector.py +++ b/test/net/test_selector.py @@ -332,6 +332,15 @@ def task(): assert len(net._scheduled) == 1 net.unschedule(t) assert len(net._scheduled) == 0 + assert t.scheduled_at is None + + def test_unschedule_unscheduled(self): + net = NetworkSelector() + def task(): + yield + assert len(net._scheduled) == 0 + net.unschedule(Task(task)) + assert len(net._scheduled) == 0 def test_reschedule(self): net = NetworkSelector() @@ -339,7 +348,16 @@ def task(): yield t = net.call_later(10, task) new_when = time.monotonic() + 0.01 - net.reschedule(t, new_when) + net.reschedule(new_when, t) + assert len(net._scheduled) == 1 + assert net._scheduled[0][0] == new_when + + def test_reschedule_unscheduled(self): + net = NetworkSelector() + def task(): + yield + new_when = time.monotonic() + 0.01 + net.reschedule(new_when, Task(task)) assert len(net._scheduled) == 1 assert net._scheduled[0][0] == new_when