From 677980f99db989c7e7866c9355e601574d241e0a Mon Sep 17 00:00:00 2001 From: Aditya Ganti Date: Sun, 5 Apr 2026 11:06:42 -0400 Subject: [PATCH] perf: coalesce empty checkpoints in batch Empty checkpoints (used by map/parallel branch resubmitters when resuming from timed waits) no longer count toward the 250-operation batch limit beyond the first. This prevents 300+ concurrent branch resumes from splitting across multiple API batches. closes #325 Co-authored-by: Aditya Ganti --- src/aws_durable_execution_sdk_python/state.py | 102 +++++--- .../e2e/map_with_concurrent_waits_int_test.py | 200 ++++++++++++++++ tests/state_test.py | 221 ++++++++++++++++++ 3 files changed, 493 insertions(+), 30 deletions(-) create mode 100644 tests/e2e/map_with_concurrent_waits_int_test.py diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 8f8ff82..16bd31d 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -592,15 +592,21 @@ def checkpoint_batches_forever(self) -> None: batch: list[QueuedOperation] = self._collect_checkpoint_batch() if batch: - # Extract OperationUpdates from QueuedOperations for API call - updates: list[OperationUpdate] = [ - q.operation_update for q in batch if q.operation_update is not None - ] + # Extract OperationUpdates, excluding empty checkpoints from API call + updates: list[OperationUpdate] = [] + empty_count = 0 + + for q in batch: + if q.operation_update is not None: + updates.append(q.operation_update) + else: + empty_count += 1 logger.debug( - "Processing checkpoint batch with %d operations (%d non-empty)", - len(batch), + "Sending %d OperationUpdates out of %d operations, excluding %d empty checkpoints", len(updates), + len(batch), + empty_count, ) try: @@ -687,26 +693,43 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]: operation if queues are empty, then collects additional operations within the time window. + Empty checkpoints (operation_update=None) are coalesced: the first empty checkpoint + counts toward the batch operation limit, but subsequent empty checkpoints do not. + All empty checkpoints remain in the batch so their completion events are signaled. + This avoids unnecessary batches when many concurrent map/parallel branches resume + simultaneously and each queues an empty checkpoint. + Returns: List of QueuedOperation objects ready for batch processing. Returns empty list if no operations are available. """ batch: list[QueuedOperation] = [] + has_empty_checkpoint = False total_size = 0 + effective_operation_count = 0 # Operations that count toward batch limit # First, drain overflow queue (FIFO order preserved) try: - while len(batch) < self._batcher_config.max_batch_operations: + while effective_operation_count < self._batcher_config.max_batch_operations: overflow_op = self._overflow_queue.get_nowait() - op_size = self._calculate_operation_size(overflow_op) - - if total_size + op_size > self._batcher_config.max_batch_size_bytes: - # Put back and stop - self._overflow_queue.put(overflow_op) - break - batch.append(overflow_op) - total_size += op_size + if overflow_op.operation_update is None: # Empty checkpoint + batch.append(overflow_op) + if not has_empty_checkpoint: + effective_operation_count += ( + 1 # First empty counts toward limit + ) + has_empty_checkpoint = True + # Subsequent empties don't count toward limit + else: + op_size = self._calculate_operation_size(overflow_op) + if total_size + op_size > self._batcher_config.max_batch_size_bytes: + # Put back and stop + self._overflow_queue.put(overflow_op) + break + batch.append(overflow_op) + total_size += op_size + effective_operation_count += 1 except queue.Empty: pass @@ -720,7 +743,13 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]: ) # Check stop signal every 100ms self._checkpoint_queue.task_done() batch.append(first_op) - total_size += self._calculate_operation_size(first_op) + + if first_op.operation_update is None: + has_empty_checkpoint = True + else: + total_size += self._calculate_operation_size(first_op) + + effective_operation_count = 1 break except queue.Empty: continue @@ -735,7 +764,7 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]: # Collect additional operations within the time window while ( time.time() < batch_deadline - and len(batch) < self._batcher_config.max_batch_operations + and effective_operation_count < self._batcher_config.max_batch_operations and not self._checkpointing_stopped.is_set() ): remaining_time = min( @@ -749,26 +778,39 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]: try: additional_op = self._checkpoint_queue.get(timeout=remaining_time) self._checkpoint_queue.task_done() - op_size = self._calculate_operation_size(additional_op) - - # Check if adding this operation would exceed size limit - if total_size + op_size > self._batcher_config.max_batch_size_bytes: - # Put in overflow queue for next batch - self._overflow_queue.put(additional_op) - logger.debug( - "Batch size limit reached, moving operation to overflow queue" - ) - break - batch.append(additional_op) - total_size += op_size + if additional_op.operation_update is None: # Empty checkpoint + batch.append(additional_op) + if not has_empty_checkpoint: + effective_operation_count += ( + 1 # First empty counts toward limit + ) + has_empty_checkpoint = True + # Subsequent empties don't count toward limit + else: + op_size = self._calculate_operation_size(additional_op) + # Check if adding this operation would exceed size limit + if total_size + op_size > self._batcher_config.max_batch_size_bytes: + # Put in overflow queue for next batch + self._overflow_queue.put(additional_op) + logger.debug( + "Batch size limit reached, moving operation to overflow queue" + ) + break + batch.append(additional_op) + total_size += op_size + effective_operation_count += 1 except queue.Empty: break + empty_count = sum(1 for q in batch if q.operation_update is None) logger.debug( - "Collected batch of %d operations, total size: %d bytes", + "Collected batch of %d operations (%d effective, %d non-empty, %d empty), total size: %d bytes", len(batch), + effective_operation_count, + len(batch) - empty_count, + empty_count, total_size, ) return batch diff --git a/tests/e2e/map_with_concurrent_waits_int_test.py b/tests/e2e/map_with_concurrent_waits_int_test.py new file mode 100644 index 0000000..8ad812e --- /dev/null +++ b/tests/e2e/map_with_concurrent_waits_int_test.py @@ -0,0 +1,200 @@ +"""Integration test: empty checkpoint coalescing with concurrent map + wait. + +Python equivalent of the Java MapWithConditionAndCallbackExample referenced in +issue #325. Verifies that when many concurrent map branches resume from timed +wait operations simultaneously, the empty checkpoints produced by the +resubmitter (executor.py) are coalesced into minimal API calls instead of +being split across multiple batches. + +Background +---------- +When a map branch suspends via TimedSuspendExecution and later resumes, the +ConcurrentExecutor resubmitter calls:: + + execution_state.create_checkpoint() # empty checkpoint + +before resubmitting the branch. In high-concurrency scenarios (300+ branches) +all resuming at the same time, 300+ empty checkpoints flood the checkpoint +queue. + +Without the coalescing optimization (issue #325), the 250-operation batch limit +causes these to be split across multiple batches → multiple API calls. +With the optimization, all subsequent empty checkpoints beyond the first do +NOT count toward the batch limit, so they are coalesced into a single batch +and a single API call. + +These tests directly simulate that concurrent-checkpoint pattern by launching +many threads that each call ``create_checkpoint()`` simultaneously, mirroring +what the map resubmitter does when all branches resume at once. +""" + +from __future__ import annotations + +import threading +from concurrent.futures import ThreadPoolExecutor + + +from aws_durable_execution_sdk_python.lambda_service import ( + CheckpointOutput, + CheckpointUpdatedExecutionState, + LambdaClient, + OperationAction, + OperationUpdate, + OperationType, +) +from aws_durable_execution_sdk_python.state import ( + CheckpointBatcherConfig, + ExecutionState, + QueuedOperation, +) +from aws_durable_execution_sdk_python.threading import CompletionEvent + +from unittest.mock import Mock + + +def _make_state( + mock_client: Mock, + batch_time: float = 5.0, + max_ops: int = 250, +) -> ExecutionState: + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=batch_time, + max_batch_operations=max_ops, + ) + return ExecutionState( + durable_execution_arn="test-arn", + initial_checkpoint_token="token-0", # noqa: S106 + operations={}, + service_client=mock_client, + batcher_config=config, + ) + + +def _make_tracking_client() -> tuple[Mock, list]: + """Return a (mock LambdaClient, checkpoint_calls list) pair.""" + calls: list[list] = [] + mock_client = Mock(spec=LambdaClient) + + def _checkpoint( + durable_execution_arn, checkpoint_token, updates, client_token=None + ): + calls.append(list(updates)) + return CheckpointOutput( + checkpoint_token=f"token_{len(calls)}", + new_execution_state=CheckpointUpdatedExecutionState(), + ) + + mock_client.checkpoint = _checkpoint + return mock_client, calls + + +def test_map_with_concurrent_waits_coalesces_empty_checkpoints(): + """300 concurrent branches all create empty checkpoints simultaneously. + + Simulates the Java MapWithConditionAndCallbackExample scenario: 300 map + branches all resuming from a wait operation at the same time, each calling + the resubmitter which enqueues an empty checkpoint. + + Without the coalescing optimization, the 250-op batch limit splits 300 + empty checkpoints into 2 batches (250 + 50) → 2 API calls. + With the optimization (effective_operation_count stays 1 for empties), + all 300 are collected in a single batch → 1 API call. + """ + mock_client, calls = _make_tracking_client() + state = _make_state(mock_client, batch_time=5.0, max_ops=250) + + batcher = ThreadPoolExecutor(max_workers=1) + batcher.submit(state.checkpoint_batches_forever) + + # 300 branches all call create_checkpoint() concurrently, each blocking + # until the batch is processed — mirrors the resubmitter pattern. + branch_count = 300 + start_barrier = threading.Barrier(branch_count) + errors: list[Exception] = [] + + def branch_work(): + try: + start_barrier.wait() # all start simultaneously + state.create_checkpoint() # empty checkpoint, synchronous + except Exception as e: # noqa: BLE001 + errors.append(e) + + threads = [threading.Thread(target=branch_work) for _ in range(branch_count)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=30) + + try: + assert not errors, f"Branch errors: {errors}" + + # All 300 empty checkpoints should be batched into 1 API call. + # Without the fix, 300 > 250 limit would produce 2 calls. + assert len(calls) == 1, ( + f"Expected 1 coalesced API call for {branch_count} concurrent empty " + f"checkpoints, got {len(calls)}. The 250-op limit must not split empties." + ) + assert calls[0] == [], "Empty checkpoints should produce an empty updates list" + finally: + state.stop_checkpointing() + batcher.shutdown(wait=True) + + +def test_map_with_concurrent_waits_api_call_count_scales_with_real_ops_not_empties(): + """400 empty checkpoints + 10 real ops → 1 API call with limit=11. + + Demonstrates that the effective batch count is driven by real operations + (and only the *first* empty), not the total number of empties. + + With limit=11: the first empty counts as effective_op 1, and each of the + 10 real ops increments the count (effective_ops 2–11). The limit is hit + exactly when the last real op is collected. All 399 remaining empties are + coalesced in without incrementing the count. + + Result: 1 batch (410 operations, 10 real) → 1 API call. + """ + mock_client, calls = _make_tracking_client() + # limit = 1 (first empty) + 10 (real ops) = 11, so all fit in one batch + state = _make_state(mock_client, batch_time=5.0, max_ops=11) + + completion_events: list[CompletionEvent] = [] + + try: + # 400 empty checkpoints (simulating concurrent branch resumes) + for _ in range(400): + ev = CompletionEvent() + completion_events.append(ev) + state._checkpoint_queue.put(QueuedOperation(None, ev)) # noqa: SLF001 + + # 10 real operations alongside the empties + for i in range(10): + op = OperationUpdate( + operation_id=f"op_{i}", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + + ev = CompletionEvent() + completion_events.append(ev) + state._checkpoint_queue.put(QueuedOperation(op, ev)) # noqa: SLF001 + + batcher = ThreadPoolExecutor(max_workers=1) + batcher.submit(state.checkpoint_batches_forever) + + # Wait for all 410 to be processed + for ev in completion_events: + ev.wait() + + # 1 empty (effective=1) + 10 real ops (effective=11) exhaust the batch + # limit exactly. The 399 remaining empties coalesce in → still 1 API call. + assert len(calls) == 1, ( + f"Expected 1 API call with 400 empty + 10 real ops (limit=11), " + f"got {len(calls)}." + ) + # Only the 10 real ops appear in the updates list; empties are excluded. + real_op_ids = {u.operation_id for batch in calls for u in batch} + assert real_op_ids == {f"op_{i}" for i in range(10)} + finally: + state.stop_checkpointing() + batcher.shutdown(wait=True) diff --git a/tests/state_test.py b/tests/state_test.py index 019b001..1f49838 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -3341,3 +3341,224 @@ def test_state_replay_mode_with_timed_out(): assert execution_state.is_replaying() is True execution_state.track_replay(operation_id="op2") assert execution_state.is_replaying() is False + + +# Tests for empty checkpoint coalescing (issue #325) + + +def test_collect_checkpoint_batch_coalesces_many_empty_checkpoints(): + """Test that many empty checkpoints are collected into a single batch. + + With the coalescing optimization, 999 empty checkpoints should all be collected + in one batch (effective_operation_count=1), not split across 4 batches of 250. + """ + mock_lambda_client = Mock(spec=LambdaClient) + + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=10.0, + max_batch_operations=250, + ) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + batcher_config=config, + ) + + # Enqueue 999 empty checkpoints (simulates high-concurrency map/parallel resume) + for _ in range(999): + state._checkpoint_queue.put(QueuedOperation(None, None)) + + # All 999 should be collected in a single batch + batch = state._collect_checkpoint_batch() + + assert len(batch) == 999 + assert all(q.operation_update is None for q in batch) + # Queue should now be empty + assert state._checkpoint_queue.empty() + + +def test_collect_checkpoint_batch_empty_checkpoints_with_real_ops_respects_limit(): + """Test that real operations still respect the max_batch_operations limit + even when many empty checkpoints are present in the same batch. + """ + mock_lambda_client = Mock(spec=LambdaClient) + + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=10.0, + max_batch_operations=5, + ) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + batcher_config=config, + ) + + # Enqueue 3 empty checkpoints + 10 real operations + for _ in range(3): + state._checkpoint_queue.put(QueuedOperation(None, None)) + for i in range(10): + op = OperationUpdate( + operation_id=f"op_{i}", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + state._checkpoint_queue.put(QueuedOperation(op, None)) + + batch = state._collect_checkpoint_batch() + + # Empty checkpoints count as 1 effective op, so 4 real ops fit in 5-op limit + empty_in_batch = sum(1 for q in batch if q.operation_update is None) + real_in_batch = sum(1 for q in batch if q.operation_update is not None) + + assert empty_in_batch == 3 # All empty checkpoints coalesced + assert real_in_batch == 4 # 4 real ops (1 slot used by the first empty) + + +def test_collect_checkpoint_batch_overflow_coalesces_empty_checkpoints(): + """Test that empty checkpoints in the overflow queue are also coalesced.""" + mock_lambda_client = Mock(spec=LambdaClient) + + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=10.0, + max_batch_operations=250, + ) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + batcher_config=config, + ) + + # Put 500 empty checkpoints directly into the overflow queue + for _ in range(500): + state._overflow_queue.put(QueuedOperation(None, None)) + + # All 500 should be collected in a single batch from overflow + batch = state._collect_checkpoint_batch() + + assert len(batch) == 500 + assert all(q.operation_update is None for q in batch) + assert state._overflow_queue.empty() + + +def test_checkpoint_batches_forever_single_api_call_for_many_empty_checkpoints(): + """Test that many empty checkpoints result in a single API call, not one per batch. + + This is the core optimization: 999 empty checkpoints should produce exactly 1 API + call instead of ceil(999/250) = 4 API calls. + """ + mock_lambda_client = Mock(spec=LambdaClient) + mock_lambda_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + + # Long time window to ensure all 999 pre-queued items are drained in one batch. + # The optimization is about the operation COUNT limit, not the time limit. + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=5.0, + max_batch_operations=250, + ) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + batcher_config=config, + ) + + # Enqueue 999 empty checkpoints with completion events + completion_events = [] + for _ in range(999): + event = CompletionEvent() + completion_events.append(event) + state._checkpoint_queue.put(QueuedOperation(None, event)) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + # Wait for all completion events to be signaled + for event in completion_events: + event.wait() + + # All 999 empty checkpoints should have been batched into a single API call + # (before the fix, the 250-item limit would split them into 4 batches) + assert mock_lambda_client.checkpoint.call_count == 1 + # The API call should have been made with an empty updates list + call_kwargs = mock_lambda_client.checkpoint.call_args + assert call_kwargs.kwargs["updates"] == [] + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_collect_checkpoint_batch_first_empty_counts_toward_limit(): + """Test that only the first empty checkpoint counts toward the batch operation limit. + + With limit=2: an empty op (effective=1) + a real op (effective=2) exactly fills the + batch. The loop exits after the limit is hit; items after the limit stay in the queue. + """ + mock_lambda_client = Mock(spec=LambdaClient) + + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=10.0, + max_batch_operations=2, + ) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + batcher_config=config, + ) + + # Queue: 1 empty (effective=1), op_1 (effective=2, hits limit), + # op_2 + trailing empties remain in queue for next batch. + op1 = OperationUpdate( + operation_id="op_1", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + op2 = OperationUpdate( + operation_id="op_2", + operation_type=OperationType.STEP, + action=OperationAction.START, + ) + state._checkpoint_queue.put(QueuedOperation(None, None)) # empty — effective=1 + state._checkpoint_queue.put( + QueuedOperation(op1, None) + ) # real — effective=2, limit hit + state._checkpoint_queue.put(QueuedOperation(op2, None)) # real — stays in queue + + for _ in range(50): + state._checkpoint_queue.put(QueuedOperation(None, None)) # trailing empties + + batch = state._collect_checkpoint_batch() + + real_in_batch = [q for q in batch if q.operation_update is not None] + empty_in_batch = [q for q in batch if q.operation_update is None] + + # The batch contains exactly: 1 leading empty + op_1 (limit=2 effective ops) + assert len(real_in_batch) == 1 + assert real_in_batch[0].operation_update.operation_id == "op_1" + assert ( + len(empty_in_batch) == 1 + ) # Only the leading empty; trailing deferred to next batch + # op_2 and trailing empties remain in the queue + assert state._checkpoint_queue.qsize() == 51