diff --git a/kafka/producer/future.py b/kafka/producer/future.py
index a5096ac48..a97acbad5 100644
--- a/kafka/producer/future.py
+++ b/kafka/producer/future.py
@@ -3,6 +3,7 @@
 
 from kafka import errors as Errors
 from kafka.future import Future
+from kafka.util import Timer
 
 
 class FutureProduceResult(Future):
@@ -57,11 +58,38 @@ def _produce_success(self, result):
                                   serialized_value_size, serialized_header_size)
         self.success(metadata)
 
+    def rebind(self, new_produce_future, new_batch_index):
+        """Rebind this future to a new produce future with a new batch index.
+
+        Used when a batch is split due to MESSAGE_TOO_LARGE. The original
+        FutureRecordMetadata is rebound to the new (smaller) batch's future.
+
+        This must be called from the sender thread while the old produce_future
+        has not been completed. Any user thread blocked in get() on the old
+        produce_future's latch will be woken and will re-wait on the new one.
+        """
+        old_produce_future = self._produce_future
+        self._produce_future = new_produce_future
+        _, timestamp_ms, checksum, sk, sv, sh = self.args
+        self.args = (new_batch_index, timestamp_ms, checksum, sk, sv, sh)
+        new_produce_future.add_callback(self._produce_success)
+        new_produce_future.add_errback(self.failure)
+        # Wake any thread blocked in get() so it re-waits on the new future.
+        # The old produce_future is never completed, so its stale callbacks
+        # (registered in __init__) will never fire.
+        old_produce_future._latch.set()
+
     def get(self, timeout=None):
-        if not self.is_done and not self._produce_future.wait(timeout):
-            raise Errors.KafkaTimeoutError(
-                "Timeout after waiting for %s secs." % (timeout,))
-        assert self.is_done
+        """Wait for up to timeout seconds for future to complete."""
+        # Loop because rebind() may wake us from the old produce_future's
+        # latch before the record is actually done. A batch may be split
+        # multiple times, so each rebind wakes us and we re-wait on the
+        # (possibly new) _produce_future.
+        timer = Timer(timeout * 1000 if timeout is not None else None)
+        while not self.is_done and not timer.expired:
+            if not self._produce_future.wait(timer.timeout_secs):
+                raise Errors.KafkaTimeoutError(
+                    "Timeout after waiting for %s secs." % (timeout,))
         if self.failed():
             raise self.exception # pylint: disable-msg=raising-bad-type
         return self.value
diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py
index 2ea94e47c..1dba19664 100644
--- a/kafka/producer/kafka.py
+++ b/kafka/producer/kafka.py
@@ -661,7 +661,9 @@ def max_usable_produce_magic(cls, api_version):
         else:
             return 0
 
-    def _estimate_size_in_bytes(self, key, value, headers=[]):
+    def _estimate_size_in_bytes(self, key, value, headers=None):
+        if headers is None:
+            headers = []
         magic = self.max_usable_produce_magic(self.config['api_version'])
         if magic == 2:
             return DefaultRecordBatchBuilder.estimate_size_in_bytes(
diff --git a/kafka/producer/producer_batch.py b/kafka/producer/producer_batch.py
index 1f5edb80d..bb83ae2e0 100644
--- a/kafka/producer/producer_batch.py
+++ b/kafka/producer/producer_batch.py
@@ -27,6 +27,7 @@ def __init__(self, tp, records, now=None):
         self.records = records
         self.topic_partition = tp
         self.produce_future = FutureProduceResult(tp)
+        self._record_futures = []
         self._retry = False
         self._final_state = None
 
@@ -66,6 +67,7 @@ def try_append(self, timestamp_ms, key, value, headers, now=None):
                                       len(key) if key is not None else -1,
                                       len(value) if value is not None else -1,
                                       sum(len(h_key.encode("utf-8")) + len(h_val) for h_key, h_val in headers) if headers else -1)
+        self._record_futures.append(future)
         return future
 
     def abort(self, exception):
diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py
index 15117dd98..db6f8c114 100644
--- a/kafka/producer/record_accumulator.py
+++ b/kafka/producer/record_accumulator.py
@@ -6,7 +6,7 @@
 
 import kafka.errors as Errors
 from kafka.producer.producer_batch import ProducerBatch
-from kafka.record.memory_records import MemoryRecordsBuilder
+from kafka.record.memory_records import MemoryRecords, MemoryRecordsBuilder
 from kafka.structs import TopicPartition
 
 
@@ -206,6 +206,88 @@ def expired_batches(self, now=None):
                 break
         return expired_batches
 
+    def split_and_reenqueue(self, batch, now=None):
+        """Split an oversized batch into smaller batches and reenqueue them.
+
+        When a produce request fails with MESSAGE_TOO_LARGE, this method splits
+        the batch into two sub-batches (by record count) and enqueues them at
+        the front of the partition's deque. The original FutureRecordMetadata
+        objects are rebound to the new batches' futures.
+
+        If the new batches are still too large, they will be split again on the
+        next MESSAGE_TOO_LARGE response.
+
+        Only supported for message_version >= 2 (DefaultRecordBatch).
+
+        Arguments:
+            batch (ProducerBatch): The oversized batch to split.
+
+        Returns:
+            int: The number of new batches created.
+ """ + now = time.monotonic() if now is None else now + tp = batch.topic_partition + + # Read all records from the closed batch + records_list = [] + for record_batch in MemoryRecords(batch.records.buffer()): + for record in record_batch: + records_list.append(record) + + # Split records into two halves by count + mid = (len(records_list) + 1) // 2 + groups = [records_list[:mid], records_list[mid:]] + + new_batches = [] + future_index = 0 + for group in groups: + if not group: + continue + builder = MemoryRecordsBuilder( + self.config['message_version'], + self.config['compression_attrs'], + self.config['batch_size'], + ) + current_batch = ProducerBatch(tp, builder, now=now) + current_batch.created = batch.created + + for record in group: + metadata = builder.append(record.timestamp, record.key, record.value, record.headers) + if metadata is None: + # Record doesn't fit (extremely unlikely for split batches). + # Finalize this batch and start a new one. + new_batches.append(current_batch) + builder = MemoryRecordsBuilder( + self.config['message_version'], + self.config['compression_attrs'], + self.config['batch_size'], + ) + current_batch = ProducerBatch(tp, builder, now=now) + current_batch.created = batch.created + metadata = builder.append(record.timestamp, record.key, record.value, record.headers) + + # Rebind original future to new batch + if future_index < len(batch._record_futures): + original_future = batch._record_futures[future_index] + original_future.rebind(current_batch.produce_future, metadata.offset) + current_batch._record_futures.append(original_future) + future_index += 1 + + new_batches.append(current_batch) + + # Enqueue in reverse order so first batch is at front of deque + with self._tp_lock(tp): + dq = self._batches[tp] + for new_batch in reversed(new_batches): + new_batch.attempts = batch.attempts + new_batch.last_attempt = now + dq.appendleft(new_batch) + self._incomplete.add(new_batch) + + log.info("Split oversized batch for %s into %d new batches (%d total records)", + tp, len(new_batches), future_index) + return len(new_batches) + def reenqueue(self, batch, now=None): """ Re-enqueue the given record batch in the accumulator. In Sender._complete_batch method, we check diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py index 6ec27f71d..64ae5eb0e 100644 --- a/kafka/producer/sender.py +++ b/kafka/producer/sender.py @@ -507,7 +507,15 @@ def _complete_batch(self, batch, partition_response): error = None if error is not None: - if self._can_retry(batch, error): + if self._can_split(batch, error): + log.warning("%s: Got %s on topic-partition %s with %d records, splitting batch and retrying", + str(self), error.__name__, batch.topic_partition, batch.record_count) + self._accumulator.split_and_reenqueue(batch) + self._maybe_remove_from_inflight_batches(batch) + self._accumulator.deallocate(batch) + if self._sensors: + self._sensors.record_retries(batch.topic_partition.topic, batch.record_count) + elif self._can_retry(batch, error): # retry log.warning("%s: Got error produce response on topic-partition %s, retrying (%s attempts left): %s%s", str(self), batch.topic_partition, @@ -566,6 +574,17 @@ def _can_retry(self, batch, error): batch.final_state is None and getattr(error, 'retriable', False)) + def _can_split(self, batch, error): + """ + We can split and retry a batch if the error indicates the batch is too + large for the broker, the batch contains more than one record (so it + can actually be split), and the delivery timeout has not been reached. 
+ """ + return (error in (Errors.MessageSizeTooLargeError, Errors.RecordListTooLargeError) and + batch.record_count > 1 and + batch.final_state is None and + not batch.has_reached_delivery_timeout(self._accumulator.delivery_timeout_ms)) + def _create_produce_requests(self, collated): """ Transfer the record batches into a list of produce requests on a diff --git a/kafka/record/memory_records.py b/kafka/record/memory_records.py index 3ef2c3bfc..0056a9951 100644 --- a/kafka/record/memory_records.py +++ b/kafka/record/memory_records.py @@ -158,11 +158,13 @@ def skip(self, offsets_to_skip): # Exposed for testing compacted records self._next_offset += offsets_to_skip - def append(self, timestamp, key, value, headers=[]): + def append(self, timestamp, key, value, headers=None): """ Append a message to the buffer. Returns: RecordMetadata or None if unable to append """ + if headers is None: + headers = [] if self._closed: return None diff --git a/kafka/util.py b/kafka/util.py index a09d02d2d..84c6d5fd1 100644 --- a/kafka/util.py +++ b/kafka/util.py @@ -35,6 +35,11 @@ def timeout_ms(self): else: return int(remaining * 1000) + @property + def timeout_secs(self): + timeout_ms = self.timeout_ms + return timeout_ms / 1000 if timeout_ms is not None else None + @property def elapsed_ms(self): return int(1000 * (time.monotonic() - self._start_at)) diff --git a/test/test_sender.py b/test/test_sender.py index 35fb9f165..372f0a0cf 100644 --- a/test/test_sender.py +++ b/test/test_sender.py @@ -267,3 +267,420 @@ def test__record_exceptions_fn(sender): record_exceptions_fn = sender._record_exceptions_fn(Errors.KafkaError('top-level'), [(0, 'err-0')], 'message') assert record_exceptions_fn(0) == Errors.KafkaError('err-0') + + +def multi_record_batch(num_records=5, topic='foo', partition=0, batch_size=100000): + """Create a ProducerBatch with multiple records for split testing.""" + tp = TopicPartition(topic, partition) + records = MemoryRecordsBuilder(magic=2, compression_type=0, batch_size=batch_size) + batch = ProducerBatch(tp, records) + futures = [] + for i in range(num_records): + future = batch.try_append(0, b'key-%d' % i, b'value-%d' % i, []) + futures.append(future) + batch.records.close() + return batch, futures + + +def test_can_split(): + """_can_split returns True for MESSAGE_TOO_LARGE with >1 record.""" + from kafka.producer.sender import Sender + batch, _ = multi_record_batch(num_records=5) + assert batch.record_count == 5 + + # _can_split is a bound method, so we test the logic directly + assert (Errors.MessageSizeTooLargeError in (Errors.MessageSizeTooLargeError, Errors.RecordListTooLargeError) + and batch.record_count > 1 + and batch.final_state is None + and not batch.has_reached_delivery_timeout(120000)) + + # Single record should not be splittable + batch1, _ = multi_record_batch(num_records=1) + assert batch1.record_count == 1 + assert not (batch1.record_count > 1) + + +def test_can_split_method(sender): + batch, _ = multi_record_batch(num_records=5) + assert sender._can_split(batch, Errors.MessageSizeTooLargeError) + assert sender._can_split(batch, Errors.RecordListTooLargeError) + assert not sender._can_split(batch, Errors.KafkaConnectionError) + assert not sender._can_split(batch, Errors.NotLeaderForPartitionError) + + # Single record: cannot split + batch1, _ = multi_record_batch(num_records=1) + assert not sender._can_split(batch1, Errors.MessageSizeTooLargeError) + + +def test_can_split_delivery_timeout(sender): + batch, _ = multi_record_batch(num_records=5) + # Simulate expired batch 
+    batch.created = time.monotonic() - 999999
+    assert not sender._can_split(batch, Errors.MessageSizeTooLargeError)
+
+
+def test_split_and_reenqueue(accumulator):
+    """RecordAccumulator.split_and_reenqueue splits a batch and enqueues new batches."""
+    tp = TopicPartition('foo', 0)
+    batch, futures = multi_record_batch(num_records=10)
+    # Add batch to incomplete tracking (normally done during append)
+    accumulator._incomplete.add(batch)
+
+    num_new = accumulator.split_and_reenqueue(batch)
+    accumulator.deallocate(batch)
+
+    assert num_new >= 2  # Should produce at least 2 new batches
+    # Check that new batches are in the deque
+    dq = accumulator._batches[tp]
+    assert len(dq) == num_new
+
+    total_records = sum(b.record_count for b in dq)
+    assert total_records == 10
+
+
+def test_split_and_reenqueue_preserves_creation_time(accumulator):
+    """Split batches preserve the original batch's creation time for delivery timeout."""
+    tp = TopicPartition('foo', 0)
+    batch, _ = multi_record_batch(num_records=4)
+    original_created = batch.created
+    accumulator._incomplete.add(batch)
+
+    accumulator.split_and_reenqueue(batch)
+    accumulator.deallocate(batch)
+
+    for new_batch in accumulator._batches[tp]:
+        assert new_batch.created == original_created
+
+
+def test_split_and_reenqueue_preserves_attempts(accumulator):
+    """Split batches inherit the original batch's attempt count."""
+    tp = TopicPartition('foo', 0)
+    batch, _ = multi_record_batch(num_records=4)
+    batch.attempts = 3
+    accumulator._incomplete.add(batch)
+
+    accumulator.split_and_reenqueue(batch)
+    accumulator.deallocate(batch)
+
+    for new_batch in accumulator._batches[tp]:
+        assert new_batch.attempts == 3
+
+
+def test_split_future_rebinding(accumulator):
+    """After split, original futures resolve when new batches complete."""
+    tp = TopicPartition('foo', 0)
+    batch, futures = multi_record_batch(num_records=4)
+    accumulator._incomplete.add(batch)
+
+    accumulator.split_and_reenqueue(batch)
+    accumulator.deallocate(batch)
+
+    # Complete each new batch and verify original futures resolve
+    dq = accumulator._batches[tp]
+    base_offset = 100
+    record_idx = 0
+    for new_batch in list(dq):
+        new_batch.complete(base_offset, -1)
+        for i in range(new_batch.record_count):
+            future = futures[record_idx]
+            assert future.is_done, "Future %d should be resolved" % record_idx
+            assert future.succeeded(), "Future %d should have succeeded" % record_idx
+            metadata = future.value
+            assert metadata.offset == base_offset + i
+            record_idx += 1
+        base_offset += 1000
+
+    assert record_idx == 4
+
+
+def test_split_future_rebinding_on_error(accumulator):
+    """After split, if a new batch fails, the original futures for those records fail."""
+    tp = TopicPartition('foo', 0)
+    batch, futures = multi_record_batch(num_records=4)
+    accumulator._incomplete.add(batch)
+
+    accumulator.split_and_reenqueue(batch)
+    accumulator.deallocate(batch)
+
+    dq = accumulator._batches[tp]
+    # Fail all new batches
+    for new_batch in list(dq):
+        error = Errors.KafkaError("test error")
+        new_batch.complete_exceptionally(error, lambda _: error)
+
+    for future in futures:
+        assert future.is_done
+        assert future.failed()
+        assert isinstance(future.exception, Errors.KafkaError)
+
+
+def test_complete_batch_splits_on_message_too_large(sender, accumulator, mocker):
+    """_complete_batch splits batch on MESSAGE_TOO_LARGE instead of failing."""
+    tp = TopicPartition('foo', 0)
+    batch, futures = multi_record_batch(num_records=5)
+    accumulator._incomplete.add(batch)
+
+    sender._complete_batch(batch, PartitionResponse(error=Errors.MessageSizeTooLargeError))
+
+    # Original batch should be deallocated (not in incomplete set)
+    assert batch not in accumulator._incomplete.all()
+
+    # New batches should be enqueued
+    dq = accumulator._batches[tp]
+    assert len(dq) >= 2
+
+    total_records = sum(b.record_count for b in dq)
+    assert total_records == 5
+
+    # Original futures should not be done yet (new batches haven't been sent)
+    for future in futures:
+        assert not future.is_done
+
+
+def test_complete_batch_splits_on_record_list_too_large(sender, accumulator, mocker):
+    """_complete_batch splits batch on RECORD_LIST_TOO_LARGE."""
+    tp = TopicPartition('foo', 0)
+    batch, futures = multi_record_batch(num_records=5)
+    accumulator._incomplete.add(batch)
+
+    sender._complete_batch(batch, PartitionResponse(error=Errors.RecordListTooLargeError))
+
+    dq = accumulator._batches[tp]
+    assert len(dq) >= 2
+    total_records = sum(b.record_count for b in dq)
+    assert total_records == 5
+
+
+def test_complete_batch_single_record_fails_normally(sender, accumulator):
+    """Single-record batch with MESSAGE_TOO_LARGE fails (cannot split)."""
+    batch, futures = multi_record_batch(num_records=1)
+    accumulator._incomplete.add(batch)
+    sender.config['retries'] = 0
+
+    sender._complete_batch(batch, PartitionResponse(error=Errors.MessageSizeTooLargeError))
+
+    assert batch.is_done
+    assert futures[0].is_done
+    assert futures[0].failed()
+    assert isinstance(futures[0].exception, Errors.MessageSizeTooLargeError)
+
+
+def test_complete_batch_split_unmutes_partition(sender, accumulator):
+    """After splitting, the partition should be unmuted for guarantee_message_order."""
+    tp = TopicPartition('foo', 0)
+    sender.config['guarantee_message_order'] = True
+    accumulator.muted.add(tp)
+
+    batch, _ = multi_record_batch(num_records=5, topic='foo', partition=0)
+    accumulator._incomplete.add(batch)
+
+    sender._complete_batch(batch, PartitionResponse(error=Errors.MessageSizeTooLargeError))
+
+    assert tp not in accumulator.muted
+
+
+def test_split_not_in_retry(accumulator):
+    """Split batches should not be marked as in_retry so sequence numbers are assigned during drain."""
+    tp = TopicPartition('foo', 0)
+    batch, _ = multi_record_batch(num_records=4)
+    accumulator._incomplete.add(batch)
+
+    accumulator.split_and_reenqueue(batch)
+
+    for new_batch in accumulator._batches[tp]:
+        assert not new_batch.in_retry()
+
+
+def test_split_with_small_batch_size():
+    """When batch_size is small, records are distributed across more batches."""
+    # Use a small batch_size to force many splits
+    accumulator = RecordAccumulator(batch_size=100)
+    tp = TopicPartition('foo', 0)
+
+    # Create a batch with large batch_size (simulating the original oversized batch)
+    batch, futures = multi_record_batch(num_records=10, batch_size=100000)
+    accumulator._incomplete.add(batch)
+
+    num_new = accumulator.split_and_reenqueue(batch)
+    accumulator.deallocate(batch)
+
+    dq = accumulator._batches[tp]
+    total_records = sum(b.record_count for b in dq)
+    assert total_records == 10
+    # With 100 byte batch_size, we expect many batches
+    assert num_new >= 2
+
+
+def test_future_rebind():
+    """FutureRecordMetadata.rebind updates produce_future and batch_index."""
+    from kafka.producer.future import FutureProduceResult, FutureRecordMetadata
+    tp = TopicPartition('foo', 0)
+
+    old_pf = FutureProduceResult(tp)
+    new_pf = FutureProduceResult(tp)
+
+    future = FutureRecordMetadata(old_pf, 5, 1000, None, 3, 5, -1)
+    assert future._produce_future is old_pf
+    assert future.args[0] == 5  # batch_index
+
+    future.rebind(new_pf, 2)
+    assert future._produce_future is new_pf
+    assert future.args[0] == 2  # new batch_index
+
+    # Complete new produce future and verify the record future resolves
+    new_pf.success((100, -1, None))
+    assert future.is_done
+    assert future.succeeded()
+    assert future.value.offset == 102  # base_offset(100) + batch_index(2)
+
+
+def test_rebind_sets_old_latch():
+    """rebind() sets the old produce_future's latch so blocked get() threads wake up."""
+    from kafka.producer.future import FutureProduceResult, FutureRecordMetadata
+    tp = TopicPartition('foo', 0)
+
+    old_pf = FutureProduceResult(tp)
+    new_pf = FutureProduceResult(tp)
+
+    future = FutureRecordMetadata(old_pf, 0, 1000, None, 3, 5, -1)
+    assert not old_pf._latch.is_set()
+
+    future.rebind(new_pf, 0)
+
+    # Old latch should be set so any thread blocked in get() wakes up
+    assert old_pf._latch.is_set()
+    # Future should not be resolved yet (new batch hasn't completed)
+    assert not future.is_done
+
+
+def test_rebind_old_produce_future_callbacks_safe():
+    """Old produce_future's stale callbacks don't crash if it is never completed."""
+    from kafka.producer.future import FutureProduceResult, FutureRecordMetadata
+    tp = TopicPartition('foo', 0)
+
+    old_pf = FutureProduceResult(tp)
+    new_pf = FutureProduceResult(tp)
+
+    future = FutureRecordMetadata(old_pf, 0, 1000, None, 3, 5, -1)
+    future.rebind(new_pf, 0)
+
+    # Complete the new produce_future; this should resolve the record future once
+    new_pf.success((100, -1, None))
+    assert future.is_done
+    assert future.succeeded()
+
+    # The old produce_future should NOT be completed
+    assert not old_pf.is_done
+
+
+def test_get_rewait_after_rebind():
+    """get() re-waits on new produce_future after being woken by rebind()."""
+    import threading
+    from kafka.producer.future import FutureProduceResult, FutureRecordMetadata
+    tp = TopicPartition('foo', 0)
+
+    old_pf = FutureProduceResult(tp)
+    future = FutureRecordMetadata(old_pf, 0, 1000, None, 3, 5, -1)
+
+    result_holder = [None]
+    error_holder = [None]
+
+    def get_in_thread():
+        try:
+            result_holder[0] = future.get(timeout=5)
+        except Exception as e:
+            error_holder[0] = e
+
+    t = threading.Thread(target=get_in_thread)
+    t.start()
+
+    # Give the thread time to block on old_pf._latch.wait()
+    import time
+    time.sleep(0.05)
+    assert t.is_alive()
+
+    # Rebind to a new produce_future; this wakes the blocked thread
+    new_pf = FutureProduceResult(tp)
+    future.rebind(new_pf, 0)
+
+    # Thread should still be alive, now waiting on new_pf
+    time.sleep(0.05)
+    assert t.is_alive()
+
+    # Complete the new produce_future
+    new_pf.success((42, -1, None))
+    t.join(timeout=5)
+    assert not t.is_alive()
+    assert error_holder[0] is None
+    assert result_holder[0] is not None
+    assert result_holder[0].offset == 42
+
+
+def test_get_rewait_after_multiple_rebinds():
+    """get() survives multiple rebinds (batch split more than once)."""
+    import threading
+    import time
+    from kafka.producer.future import FutureProduceResult, FutureRecordMetadata
+    tp = TopicPartition('foo', 0)
+
+    pf1 = FutureProduceResult(tp)
+    future = FutureRecordMetadata(pf1, 0, 1000, None, 3, 5, -1)
+
+    result_holder = [None]
+    error_holder = [None]
+
+    def get_in_thread():
+        try:
+            result_holder[0] = future.get(timeout=5)
+        except Exception as e:
+            error_holder[0] = e
+
+    t = threading.Thread(target=get_in_thread)
+    t.start()
+    time.sleep(0.05)
+
+    # First rebind (first split)
+    pf2 = FutureProduceResult(tp)
+    future.rebind(pf2, 0)
+    time.sleep(0.05)
+    assert t.is_alive()
+
+    # Second rebind (second split)
+    pf3 = FutureProduceResult(tp)
+    future.rebind(pf3, 0)
+    time.sleep(0.05)
+    assert t.is_alive()
+
+    # Finally complete
+    pf3.success((99, -1, None))
+    t.join(timeout=5)
+    assert not t.is_alive()
+    assert error_holder[0] is None
+    assert result_holder[0].offset == 99
+
+
+def test_end_to_end_split_and_complete(accumulator):
+    """End-to-end: split a batch, complete new batches, verify all original futures resolve."""
+    tp = TopicPartition('foo', 0)
+    batch, futures = multi_record_batch(num_records=8)
+    accumulator._incomplete.add(batch)
+
+    accumulator.split_and_reenqueue(batch)
+    accumulator.deallocate(batch)
+
+    dq = accumulator._batches[tp]
+    new_batches = list(dq)
+
+    # Simulate sending and completing each new batch
+    offset = 0
+    for new_batch in new_batches:
+        new_batch.complete(offset, -1)
+        offset += new_batch.record_count
+
+    # All original futures should be resolved with correct offsets
+    for i, future in enumerate(futures):
+        assert future.is_done, "Future %d not done" % i
+        assert future.succeeded(), "Future %d failed: %s" % (i, future.exception)
+        assert future.value.offset == i
+        assert future.value.topic == 'foo'
+        assert future.value.partition == 0
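
Usage note (not part of the diff above): rebind() has to update both the produce future and the record's index within its batch, because FutureRecordMetadata derives a record's absolute offset from the new batch's base offset plus that index. A minimal sketch of that bookkeeping, using only the classes touched by this change; the index, offset, and size values are illustrative, not taken from any real produce response.

from kafka.producer.future import FutureProduceResult, FutureRecordMetadata
from kafka.structs import TopicPartition

tp = TopicPartition('foo', 0)

# The record starts out as index 5 of an oversized batch.
old_pf = FutureProduceResult(tp)
future = FutureRecordMetadata(old_pf, 5, 1000, None, 3, 5, -1)

# After split_and_reenqueue() it becomes index 2 of a smaller batch, so
# rebind() points the future at the new batch and updates the index.
new_pf = FutureProduceResult(tp)
future.rebind(new_pf, 2)

# When the smaller batch lands at base offset 100, the record's metadata
# reports offset 100 + 2; the old produce future is never completed.
new_pf.success((100, -1, None))
assert future.value.offset == 102
assert not old_pf.is_done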