diff --git a/kafka/__init__.py b/kafka/__init__.py index aff51c166..29131ace4 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -12,7 +12,9 @@ from kafka.admin import KafkaAdminClient from kafka.consumer import KafkaConsumer -from kafka.consumer.subscription_state import ConsumerRebalanceListener +from kafka.consumer.subscription_state import ( + AsyncConsumerRebalanceListener, ConsumerRebalanceListener, +) from kafka.producer import KafkaProducer from kafka.serializer import Serializer, Deserializer from kafka.structs import TopicPartition, TopicPartitionReplica, OffsetAndMetadata @@ -21,7 +23,8 @@ __all__ = [ 'KafkaAdminClient', 'KafkaConsumer', 'KafkaProducer', - 'ConsumerRebalanceListener', 'Serializer', 'Deserializer', + 'AsyncConsumerRebalanceListener', 'ConsumerRebalanceListener', + 'Serializer', 'Deserializer', 'TopicPartition', 'TopicPartitionReplica', 'OffsetAndMetadata', 'IsolationLevel', 'OffsetSpec', ] diff --git a/kafka/consumer/subscription_state.py b/kafka/consumer/subscription_state.py index ab491d24b..b5c5344c5 100644 --- a/kafka/consumer/subscription_state.py +++ b/kafka/consumer/subscription_state.py @@ -134,8 +134,10 @@ def subscribe(self, topics=(), pattern=None, listener=None): self._set_subscription_type(SubscriptionType.AUTO_TOPICS) self.change_subscription(topics) - if listener and not isinstance(listener, ConsumerRebalanceListener): - raise TypeError('listener must be a ConsumerRebalanceListener') + if listener and not isinstance( + listener, (ConsumerRebalanceListener, AsyncConsumerRebalanceListener)): + raise TypeError( + 'listener must be a ConsumerRebalanceListener or AsyncConsumerRebalanceListener') self.rebalance_listener = listener @synchronized @@ -528,8 +530,15 @@ class ConsumerRebalanceListener(metaclass=abc.ABCMeta): it may want to automatically trigger a flush of this cache, before the new owner takes over consumption. 
- This callback will execute in the user thread as part of the Consumer.poll() - whenever partition assignment changes. + Threading: callbacks run on the consumer's IO event loop, the same loop + that drives heartbeats. Sync listener methods must return promptly -- + blocking IO inside a sync listener will block heartbeats for the duration + and can cause the consumer to be kicked from the group if the listener + runs longer than ``session_timeout_ms``. For listeners that need to do + blocking work (e.g. flushing state to a database), prefer + :class:`AsyncConsumerRebalanceListener`, which lets you ``await`` while + keeping the loop responsive, or wrap the blocking call in your own + worker thread. It is guaranteed that all consumer processes will invoke on_partitions_revoked() prior to any process invoking @@ -574,3 +583,48 @@ def on_partitions_assigned(self, assigned): consumer (may include partitions that were previously assigned) """ pass + + +class AsyncConsumerRebalanceListener(metaclass=abc.ABCMeta): + """ + Async variant of :class:`ConsumerRebalanceListener`. + + Implement this when your rebalance hooks need to perform IO that would + otherwise block the consumer's event loop -- e.g. flushing state to a + database, calling an external service, or coordinating with other async + code. The coordinator detects coroutine functions and awaits them + instead of calling inline, so other tasks on the loop (notably the + heartbeat coroutine) continue to run while your listener is suspended. + + Same lifecycle and ordering guarantees as the sync listener: all + consumers in the group invoke ``on_partitions_revoked`` before any + invokes ``on_partitions_assigned``. Both methods must be defined as + ``async def``; otherwise use :class:`ConsumerRebalanceListener`. + """ + @abc.abstractmethod + async def on_partitions_revoked(self, revoked): + """Async-callback for the start of a rebalance operation. 
+ + See :meth:`ConsumerRebalanceListener.on_partitions_revoked` for + semantics. The coordinator awaits this method, so non-blocking IO + via ``await`` keeps the heartbeat loop responsive during the call. + + Arguments: + revoked (set of TopicPartition): the partitions that were + assigned to the consumer on the last rebalance. + """ + pass + + @abc.abstractmethod + async def on_partitions_assigned(self, assigned): + """Async-callback for the completion of a partition re-assignment. + + See :meth:`ConsumerRebalanceListener.on_partitions_assigned` for + semantics. + + Arguments: + assigned (set of TopicPartition): the partitions assigned to + the consumer (may include partitions that were previously + assigned). + """ + pass diff --git a/kafka/coordinator/base.py b/kafka/coordinator/base.py index 98690b0a0..4544173b6 100644 --- a/kafka/coordinator/base.py +++ b/kafka/coordinator/base.py @@ -150,9 +150,14 @@ def __init__(self, client, **configs): self.rejoin_needed = True self.rejoining = False # renamed / complement of java needsJoinPrepare self.state = MemberState.UNJOINED - self.join_future = None self.coordinator_id = None self._find_coordinator_future = None + # In-flight JoinGroup -> SyncGroup task cached across poll re-entries. + # consumer.poll(timeout_ms=N) may give up while a JoinGroup is still + # pending on the broker (e.g. broker waiting for other members to + # rejoin); the next poll re-awaits this task instead of sending a + # duplicate JoinGroup. Cleared on success or non-retriable failure. + self._join_task = None self._generation = Generation.NO_GENERATION if self.config['metrics']: self._sensors = GroupCoordinatorMetrics(self.heartbeat, self.config['metrics'], @@ -197,12 +202,14 @@ def group_protocols(self): """ pass - @abc.abstractmethod - def _on_join_prepare(self, generation, member_id, timeout_ms=None): + async def _on_join_prepare_async(self, generation, member_id, timeout_ms=None): """Invoked prior to each group join or rejoin. 
- This is typically used to perform any cleanup from the previous - generation (such as committing offsets for the consumer) + Subclasses (e.g. :class:`ConsumerCoordinator`) override with auto-commit + + rebalance-listener invocation. Called from the join coroutine on the + event loop, so blocking work in subclass overrides will block the loop + -- including heartbeats. Async rebalance listeners are awaited; sync + listeners run inline. Arguments: generation (int): The previous generation or -1 if there was none @@ -232,11 +239,12 @@ def _perform_assignment(self, leader_id, protocol, members): """ pass - @abc.abstractmethod - def _on_join_complete(self, generation, member_id, protocol, - member_assignment_bytes): + async def _on_join_complete_async(self, generation, member_id, protocol, + member_assignment_bytes): """Invoked when a group member has successfully joined a group. + Subclasses override to apply the assignment and run user listeners. + Arguments: generation (int): the generation that was joined member_id (str): the identifier for the local member in the group @@ -388,58 +396,14 @@ def time_to_next_heartbeat(self): return float('inf') return self.heartbeat.time_to_next_heartbeat() - def _reset_join_group_future(self): - with self._lock: - self.join_future = None - - def _initiate_join_group(self): - with self._lock: - # we store the join future in case we are woken up by the user - # after beginning the rebalance in the call to poll below. - # This ensures that we do not mistakenly attempt to rejoin - # before the pending rebalance has completed. 
- if self.join_future is None: - log.debug("_initiate_join_group: creating new join_future (state=%s)", self.state) - self.state = MemberState.REBALANCING - self.join_future = self._send_join_group_request() - - # handle join completion in the callback so that the - # callback will be invoked even if the consumer is woken up - # before finishing the rebalance - self.join_future.add_callback(self._handle_join_success) - - # we handle failures below after the request finishes. - # If the join completes after having been woken up, the - # exception is ignored and we will rejoin - self.join_future.add_errback(self._handle_join_failure) - else: - log.debug("_initiate_join_group: returning existing join_future (is_done=%s, exception=%s, state=%s)", - self.join_future.is_done, self.join_future.exception, self.state) - - return self.join_future - - def _handle_join_success(self, member_assignment_bytes): - # handle join completion in the callback so that the callback - # will be invoked even if the consumer is woken up before - # finishing the rebalance - with self._lock: - self.state = MemberState.STABLE - self._enable_heartbeat() - - def _handle_join_failure(self, exception): - # we handle failures below after the request finishes. - # if the join completes after having been woken up, - # the exception is ignored and we will rejoin - with self._lock: - log.info("Failed to join group %s: %s", self.group_id, exception) - self.state = MemberState.UNJOINED - @property def _use_group_apis(self): return self.config['api_version'] >= (0, 9) def ensure_active_group(self, timeout_ms=None): - """Ensure that the group is active (i.e. joined and synced) + """Ensure that the group is active (i.e. joined and synced). + + Sync facade over :meth:`ensure_active_group_async`. 
Keyword Arguments: timeout_ms (numeric, optional): Maximum number of milliseconds to @@ -447,135 +411,103 @@ def ensure_active_group(self, timeout_ms=None): Returns: True if group initialized before timeout_ms, else False """ + with self._client._lock: + return self._manager.run(self.ensure_active_group_async, timeout_ms) + + async def ensure_active_group_async(self, timeout_ms=None): + """Async variant of :meth:`ensure_active_group`.""" if not self._use_group_apis: raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker') timer = Timer(timeout_ms) - if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms): + if not await self.ensure_coordinator_ready_async(timeout_ms=timer.timeout_ms): return False - # If either loop or thread died w/ exception we restart them here self._maybe_start_heartbeat_loop() self._maybe_start_heartbeat_thread() - return self.join_group(timeout_ms=timer.timeout_ms) + return await self.join_group_async(timeout_ms=timer.timeout_ms) + + async def join_group_async(self, timeout_ms=None): + """Drive JoinGroup -> SyncGroup attempts until joined or aborted. - def join_group(self, timeout_ms=None): + Internal: the only entry point is :meth:`ensure_active_group_async` + (and its sync facade :meth:`ensure_active_group`). + + Returns True when the member has been (re-)joined, False on timer + expiry, or raises on a non-retriable error. + """ if not self._use_group_apis: raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker') timer = Timer(timeout_ms) while self.need_rejoin(): - if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms): + if not await self.ensure_coordinator_ready_async(timeout_ms=timer.timeout_ms): return False - # call on_join_prepare if needed. We set a flag - # to make sure that we do not call it a second - # time if the client is woken up before a pending - # rebalance completes. 
This must be called on each - # iteration of the loop because an event requiring - # a rebalance (such as a metadata refresh which - # changes the matched subscription set) can occur - # while another rebalance is still in progress. - if not self.rejoining: - self._on_join_prepare(self._generation.generation_id, - self._generation.member_id, - timeout_ms=timer.timeout_ms) - self.rejoining = True - - # fence off the heartbeat thread explicitly so that it cannot - # interfere with the join group. # Note that this must come after - # the call to onJoinPrepare since we must be able to continue - # sending heartbeats if that callback takes some time. - log.debug("Disabling heartbeat during join-group") - self._disable_heartbeat() + # Schedule the join attempt as a Task on first entry; subsequent + # poll iterations re-await the same Task while the broker is still + # processing JoinGroup. Without this cache, a short + # consumer.poll(timeout_ms=N) that gives up on the first iteration + # would send a fresh JoinGroup on the next iteration, confusing + # the broker. + if self._join_task is None or self._join_task.is_done: + # Call _on_join_prepare once per rebalance attempt. The rejoining + # flag survives across loop iterations so we don't re-run user + # listeners or auto-commit on retry. + if not self.rejoining: + await self._on_join_prepare_async( + self._generation.generation_id, + self._generation.member_id, + timeout_ms=timer.timeout_ms) + self.rejoining = True + + # Disable heartbeat for the wire round-trip. Must come AFTER + # _on_join_prepare_async so heartbeats keep flowing while a + # potentially-slow rebalance listener runs. + log.debug("Disabling heartbeat during join-group") + self._disable_heartbeat() + + self._join_task = self._manager.call_soon(self._do_join_and_sync_async) - # ensure that there are no pending requests to the coordinator. - # This is important in particular to avoid resending a pending - # JoinGroup request. 
- while not self.coordinator_unknown(): - if not self._client.in_flight_request_count(self.coordinator_id): - break - poll_timeout_ms = 200 if timer.timeout_ms is None or timer.timeout_ms > 200 else timer.timeout_ms - self._client.poll(timeout_ms=poll_timeout_ms) + try: + assignment_bytes = await self._manager.wait_for( + self._join_task, timer.timeout_ms) + except Errors.KafkaTimeoutError: + # Timer expired; leave self._join_task in flight so the next + # poll re-awaits it instead of sending a duplicate JoinGroup. + return False + except (Errors.UnknownMemberIdError, + Errors.RebalanceInProgressError, + Errors.IllegalGenerationError, + Errors.MemberIdRequiredError): + # Side effects (reset_generation / coordinator_dead / + # request_rejoin) were applied by the response processors; + # loop back and retry immediately. + self._join_task = None + continue + except Errors.KafkaError as exc: + self._join_task = None + if not getattr(exc, 'retriable', False): + raise if timer.expired: return False - else: + backoff_ms = self.config['retry_backoff_ms'] + if timer.timeout_ms is not None: + backoff_ms = min(backoff_ms, timer.timeout_ms) + if backoff_ms > 0: + await self._manager._net.sleep(backoff_ms / 1000) continue + self._join_task = None - future = self._initiate_join_group() - self._client.poll(future=future, timeout_ms=timer.timeout_ms) - log.debug("join_group: after poll, future.is_done=%s future.exception=%s future is self.join_future=%s state=%s", - future.is_done, future.exception, future is self.join_future, self.state) - if future.is_done: - self._reset_join_group_future() - else: - return False - - log.debug("join_group: checking future.succeeded()=%s (is_done=%s exception=%s)", - future.succeeded(), future.is_done, future.exception) - if future.succeeded(): + with self._lock: self.rejoining = False self.rejoin_needed = False - log.debug("join_group: about to call _on_join_complete (generation=%s)", self._generation) - 
self._on_join_complete(self._generation.generation_id, - self._generation.member_id, - self._generation.protocol, - future.value) - log.debug("join_group: _on_join_complete returned") - return True - else: - exception = future.exception - if isinstance(exception, (Errors.UnknownMemberIdError, - Errors.RebalanceInProgressError, - Errors.IllegalGenerationError, - Errors.MemberIdRequiredError)): - continue - elif not future.retriable(): - raise exception # pylint: disable-msg=raising-bad-type - elif timer.expired: - return False - else: - if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']: - time.sleep(self.config['retry_backoff_ms'] / 1000) - else: - time.sleep(timer.timeout_ms / 1000) - - def _send_join_group_request(self): - """Join the group and return the assignment for the next generation. - - This function handles both JoinGroup and SyncGroup, delegating to - :meth:`._perform_assignment` if elected leader by the coordinator. - - Returns: - Future: resolves to the encoded-bytes assignment returned from the - group leader - """ - if self.coordinator_unknown(): - e = Errors.CoordinatorNotAvailableError(self.coordinator_id) - return Future().failure(e) - - elif not self._client.ready(self.coordinator_id, metadata_priority=False): - e = Errors.NodeNotReadyError(self.coordinator_id) - return Future().failure(e) - - # send a join group request to the coordinator - log.info("(Re-)joining group %s", self.group_id) - max_version = 6 - request = JoinGroupRequest( - group_id=self.group_id, - session_timeout_ms=self.config['session_timeout_ms'], - rebalance_timeout_ms=self.config['max_poll_interval_ms'], - member_id=self._generation.member_id, - group_instance_id=self.group_instance_id, - protocol_type=self.protocol_type(), - protocols=self.group_protocols(), - max_version=max_version) - - # create the request for the coordinator - log.debug("Sending JoinGroup (%s) to coordinator %s", request, self.coordinator_id) - future = Future() - _f = 
self._manager.send(request, node_id=self.coordinator_id) - _f.add_callback(self._handle_join_group_response, future, time.monotonic()) - _f.add_errback(self._failed_request, self.coordinator_id, - request, future) - return future + self.state = MemberState.STABLE + self._enable_heartbeat() + await self._on_join_complete_async( + self._generation.generation_id, + self._generation.member_id, + self._generation.protocol, + assignment_bytes) + return True + return True def _failed_request(self, node_id, request, future, error): # Marking coordinator dead @@ -592,7 +524,18 @@ def _failed_request(self, node_id, request, future, error): if future is not None: future.failure(error) - def _handle_join_group_response(self, future, send_time, response): + def _process_join_group_response(self, response, send_time): + """Classify a JoinGroupResponse: mutate state on success, raise on error. + + Used by :meth:`_do_join_and_sync_async`. Callers route to leader or + follower based on the returned response. + + Returns: + JoinGroupResponse: the response (caller does leader/follower routing). + Raises: + Errors.KafkaError: subclass matching the response error code. + UnjoinedGroupException: state is no longer REBALANCING. + """ log.debug("Received JoinGroup response: %s", response) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: @@ -600,173 +543,186 @@ def _handle_join_group_response(self, future, send_time, response): self._sensors.join_latency.record((time.monotonic() - send_time) * 1000) with self._lock: if self.state is not MemberState.REBALANCING: - # if the consumer was woken up before a rebalance completes, - # we may have already left the group. In this case, we do - # not want to continue with the sync group. 
- future.failure(UnjoinedGroupException()) - else: - self._generation = Generation(response.generation_id, - response.member_id, - response.protocol_name) - - log.info("Successfully joined group %s %s", self.group_id, self._generation) - if response.leader == response.member_id: - log.info("Elected group leader -- performing partition" - " assignments using %s", self._generation.protocol) - self._on_join_leader(response).chain(future) - else: - self._on_join_follower().chain(future) - - elif error_type is Errors.CoordinatorLoadInProgressError: + raise UnjoinedGroupException() + self._generation = Generation(response.generation_id, + response.member_id, + response.protocol_name) + log.info("Successfully joined group %s %s", self.group_id, self._generation) + return response + + if error_type is Errors.CoordinatorLoadInProgressError: log.info("Attempt to join group %s rejected since coordinator %s" " is loading the group.", self.group_id, self.coordinator_id) - # backoff and retry - future.failure(error_type(response)) - elif error_type is Errors.UnknownMemberIdError: - # reset the member id and retry immediately + raise error_type(response) + + if error_type is Errors.UnknownMemberIdError: error = error_type(self._generation.member_id) self.reset_generation() log.info("Attempt to join group %s failed due to unknown member id", self.group_id) - future.failure(error) - elif error_type in (Errors.CoordinatorNotAvailableError, - Errors.NotCoordinatorError): - # re-discover the coordinator and retry with backoff + raise error + + if error_type in (Errors.CoordinatorNotAvailableError, + Errors.NotCoordinatorError): self.coordinator_dead(error_type()) log.info("Attempt to join group %s failed due to obsolete " "coordinator information: %s", self.group_id, error_type.__name__) - future.failure(error_type()) - elif error_type in (Errors.InconsistentGroupProtocolError, - Errors.InvalidSessionTimeoutError, - Errors.InvalidGroupIdError, - Errors.GroupAuthorizationFailedError, - 
Errors.GroupMaxSizeReachedError, - Errors.FencedInstanceIdError): - # log the error and re-throw the exception + raise error_type() + + if error_type in (Errors.InconsistentGroupProtocolError, + Errors.InvalidSessionTimeoutError, + Errors.InvalidGroupIdError, + Errors.GroupAuthorizationFailedError, + Errors.GroupMaxSizeReachedError, + Errors.FencedInstanceIdError): log.error("Attempt to join group %s failed due to fatal error: %s", self.group_id, error_type.__name__) - if error_type in (Errors.GroupAuthorizationFailedError, Errors.GroupMaxSizeReachedError): - future.failure(error_type(self.group_id)) - else: - future.failure(error_type()) - elif error_type is Errors.MemberIdRequiredError: - # Broker requires a concrete member id to be allowed to join the group. Update member id - # and send another join group request in next cycle. + if error_type in (Errors.GroupAuthorizationFailedError, + Errors.GroupMaxSizeReachedError): + raise error_type(self.group_id) + raise error_type() + + if error_type is Errors.MemberIdRequiredError: log.info("Received member id %s for group %s; will retry join-group", response.member_id, self.group_id) self.reset_generation(response.member_id) - future.failure(error_type()) - elif error_type is Errors.RebalanceInProgressError: + raise error_type() + + if error_type is Errors.RebalanceInProgressError: log.info("Attempt to join group %s failed due to RebalanceInProgressError," " which could indicate a replication timeout on the broker. 
Will retry.", self.group_id) - future.failure(error_type()) - else: - # unexpected error, throw the exception - error = error_type() - log.error("Unexpected error in join group response: %s", error) - future.failure(error) + raise error_type() - def _on_join_follower(self): - # send follower's sync group with an empty assignment - max_version = 4 - request = SyncGroupRequest( - group_id=self.group_id, - generation_id=self._generation.generation_id, - member_id=self._generation.member_id, - group_instance_id=self.group_instance_id, - assignments=[], - max_version=max_version) - log.debug("Sending follower SyncGroup for group %s to coordinator %s: %s", - self.group_id, self.coordinator_id, request) - return self._send_sync_group_request(request) + error = error_type() + log.error("Unexpected error in join group response: %s", error) + raise error - def _on_join_leader(self, response): - """ - Perform leader synchronization and send back the assignment - for the group via SyncGroupRequest + def _process_sync_group_response(self, response, send_time): + """Classify a SyncGroupResponse: return assignment bytes or raise. - Arguments: - response (JoinResponse): broker response to parse + Used by :meth:`_do_join_and_sync_async`. Applies ``request_rejoin()`` + / ``coordinator_dead()`` / ``reset_generation()`` side effects on + the relevant error codes. Returns: - Future: resolves to member assignment encoded-bytes + bytes: encoded member assignment. + Raises: + Errors.KafkaError: subclass matching the response error code. 
""" - try: - group_assignment = self._perform_assignment(response.leader, - response.protocol_name, - response.members) - except Exception as e: - return Future().failure(e) - - max_version = 4 - request = SyncGroupRequest( - group_id=self.group_id, - generation_id=self._generation.generation_id, - member_id=self._generation.member_id, - group_instance_id=self.group_instance_id, - assignments=group_assignment.items(), - max_version=max_version) - log.debug("Sending leader SyncGroup for group %s to coordinator %s: %s", - self.group_id, self.coordinator_id, request) - return self._send_sync_group_request(request) - - def _send_sync_group_request(self, request): - if self.coordinator_unknown(): - e = Errors.CoordinatorNotAvailableError(self.coordinator_id) - return Future().failure(e) - - # We assume that coordinator is ready if we're sending SyncGroup - # as it typically follows a successful JoinGroup - # Also note that if client.ready() enforces a metadata priority policy, - # we can get into an infinite loop if the leader assignment process - # itself requests a metadata update - - future = Future() - _f = self._manager.send(request, node_id=self.coordinator_id) - _f.add_callback(self._handle_sync_group_response, future, time.monotonic()) - _f.add_errback(self._failed_request, self.coordinator_id, - request, future) - return future - - def _handle_sync_group_response(self, future, send_time, response): log.debug("Received SyncGroup response: %s", response) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: if self._sensors: self._sensors.sync_latency.record((time.monotonic() - send_time) * 1000) - future.success(response.assignment) - return + return response.assignment # Always rejoin on error self.request_rejoin() if error_type is Errors.GroupAuthorizationFailedError: - future.failure(error_type(self.group_id)) - elif error_type is Errors.RebalanceInProgressError: - log.info("SyncGroup for group %s failed due to coordinator" - " 
rebalance", self.group_id) - future.failure(error_type(self.group_id)) - elif error_type is Errors.FencedInstanceIdError: + raise error_type(self.group_id) + if error_type is Errors.RebalanceInProgressError: + log.info("SyncGroup for group %s failed due to coordinator rebalance", + self.group_id) + raise error_type(self.group_id) + if error_type is Errors.FencedInstanceIdError: log.error("SyncGroup for group %s failed due to fenced id error: %s", self.group_id, self.group_instance_id) - future.failure(error_type((self.group_id, self.group_instance_id))) - elif error_type in (Errors.UnknownMemberIdError, - Errors.IllegalGenerationError): + raise error_type((self.group_id, self.group_instance_id)) + if error_type in (Errors.UnknownMemberIdError, Errors.IllegalGenerationError): error = error_type() log.info("SyncGroup for group %s failed due to %s", self.group_id, error) self.reset_generation() - future.failure(error) - elif error_type in (Errors.CoordinatorNotAvailableError, - Errors.NotCoordinatorError): + raise error + if error_type in (Errors.CoordinatorNotAvailableError, + Errors.NotCoordinatorError): error = error_type() log.info("SyncGroup for group %s failed due to %s", self.group_id, error) self.coordinator_dead(error) - future.failure(error) + raise error + error = error_type() + log.error("Unexpected error from SyncGroup: %s", error) + raise error + + async def _do_join_and_sync_async(self): + """Run a single JoinGroup -> SyncGroup attempt against the coordinator. + + Sends a JoinGroupRequest and processes the response (mutates + self._generation on success). Then dispatches as group leader + (running the configured assignor) or follower (empty assignment), + sends the matching SyncGroupRequest, and returns the assignment + bytes from the response. + + The outer retry loop in :meth:`join_group_async` handles backoff + and retriable errors; this method makes exactly one join/sync attempt. 
+ + Returns: + bytes: the encoded member assignment from SyncGroupResponse. + + Raises: + Errors.CoordinatorNotAvailableError: if the coordinator is unknown. + Errors.KafkaError: on any error response from JoinGroup or + SyncGroup. Side effects (coordinator_dead, reset_generation, + request_rejoin) are applied by the response processors. + Exception: anything raised by ``_perform_assignment`` + (e.g. assignor crash); leader-only path. + """ + if self.coordinator_unknown(): + raise Errors.CoordinatorNotAvailableError(self.coordinator_id) + + with self._lock: + self.state = MemberState.REBALANCING + + log.info("(Re-)joining group %s", self.group_id) + join_request = JoinGroupRequest( + group_id=self.group_id, + session_timeout_ms=self.config['session_timeout_ms'], + rebalance_timeout_ms=self.config['max_poll_interval_ms'], + member_id=self._generation.member_id, + group_instance_id=self.group_instance_id, + protocol_type=self.protocol_type(), + protocols=self.group_protocols(), + max_version=6) + log.debug("Sending JoinGroup (%s) to coordinator %s", + join_request, self.coordinator_id) + join_send_time = time.monotonic() + join_response = await self._manager.send( + join_request, node_id=self.coordinator_id) + # raises on error; mutates self._generation on success + self._process_join_group_response(join_response, join_send_time) + + if join_response.leader == join_response.member_id: + log.info("Elected group leader -- performing partition assignments" + " using %s", self._generation.protocol) + group_assignment = self._perform_assignment( + join_response.leader, + join_response.protocol_name, + join_response.members) + sync_request = SyncGroupRequest( + group_id=self.group_id, + generation_id=self._generation.generation_id, + member_id=self._generation.member_id, + group_instance_id=self.group_instance_id, + assignments=group_assignment.items(), + max_version=4) + log.debug("Sending leader SyncGroup for group %s to coordinator %s: %s", + self.group_id, 
self.coordinator_id, sync_request) else: - error = error_type() - log.error("Unexpected error from SyncGroup: %s", error) - future.failure(error) + sync_request = SyncGroupRequest( + group_id=self.group_id, + generation_id=self._generation.generation_id, + member_id=self._generation.member_id, + group_instance_id=self.group_instance_id, + assignments=[], + max_version=4) + log.debug("Sending follower SyncGroup for group %s to coordinator %s: %s", + self.group_id, self.coordinator_id, sync_request) + + sync_send_time = time.monotonic() + sync_response = await self._manager.send( + sync_request, node_id=self.coordinator_id) + return self._process_sync_group_response(sync_response, sync_send_time) def _send_group_coordinator_request(self): """Discover the current coordinator for the group. diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py index 1b00a2492..cec05faac 100644 --- a/kafka/coordinator/consumer.py +++ b/kafka/coordinator/consumer.py @@ -1,6 +1,7 @@ import collections import copy import functools +import inspect import logging import time @@ -223,8 +224,38 @@ def _build_metadata_snapshot(self, subscription, cluster): def _lookup_assignor(self, name): return self._assignors.get(name, None) - def _on_join_complete(self, generation, member_id, protocol, - member_assignment_bytes): + # Threshold above which a rebalance-listener invocation is logged as a + # warning. Sync listeners on the IO loop will block heartbeats while + # they run; even async ones delay rebalance progress. 1s is a soft + # ceiling: well below default heartbeat_interval_ms (3s) and + # session_timeout_ms (10s). + _REBALANCE_LISTENER_WARN_SECS = 1.0 + + async def _invoke_rebalance_listener_async(self, method_name, arg): + """Invoke a rebalance-listener method (sync or async), timing the call. + + Awaits if the method is a coroutine function; otherwise calls inline. + Logs a warning if the call exceeds + :data:`_REBALANCE_LISTENER_WARN_SECS`. Caller wraps in try/except. 
+ """ + cb = getattr(self._subscription.rebalance_listener, method_name) + start = time.monotonic() + if inspect.iscoroutinefunction(cb): + await cb(arg) + else: + cb(arg) + elapsed = time.monotonic() - start + if elapsed > self._REBALANCE_LISTENER_WARN_SECS: + log.warning( + "Rebalance listener %s.%s for group %s took %.3fs." + " Sync listeners block the consumer event loop (including" + " heartbeats) -- consider AsyncConsumerRebalanceListener or" + " wrap blocking work in a worker thread.", + type(self._subscription.rebalance_listener).__name__, + method_name, self.group_id, elapsed) + + async def _on_join_complete_async(self, generation, member_id, protocol, + member_assignment_bytes): # only the leader is responsible for monitoring for metadata changes # (i.e. partition changes) if not self._is_leader: @@ -255,7 +286,8 @@ def _on_join_complete(self, generation, member_id, protocol, # execute the user's callback after rebalance if self._subscription.rebalance_listener: try: - self._subscription.rebalance_listener.on_partitions_assigned(assigned) + await self._invoke_rebalance_listener_async( + 'on_partitions_assigned', assigned) except Exception: log.exception("User provided rebalance listener %s for group %s" " failed on partition assignment: %s", @@ -358,9 +390,22 @@ def _perform_assignment(self, leader_id, protocol_name, members): group_assignment[member_id] = assignment return group_assignment - def _on_join_prepare(self, generation, member_id, timeout_ms=None): + async def _on_join_prepare_async(self, generation, member_id, timeout_ms=None): # commit offsets prior to rebalance if auto-commit enabled - self._maybe_auto_commit_offsets_sync(timeout_ms=timeout_ms) + if self.config['enable_auto_commit']: + try: + await self._commit_offsets_sync_async( + self._subscription.all_consumed_offsets(), + timeout_ms=timeout_ms) + except (Errors.UnknownMemberIdError, + Errors.IllegalGenerationError, + Errors.RebalanceInProgressError): + log.warning("Pre-rebalance offset 
commit failed: group membership" + " out of date. This is likely to cause duplicate" + " message delivery.") + except Exception: + log.exception("Pre-rebalance offset commit failed: This is likely" + " to cause duplicate message delivery") # execute the user's callback before rebalance log.info("Revoking previously assigned partitions %s for group %s", @@ -368,7 +413,8 @@ def _on_join_prepare(self, generation, member_id, timeout_ms=None): if self._subscription.rebalance_listener: try: revoked = set(self._subscription.assigned_partitions()) - self._subscription.rebalance_listener.on_partitions_revoked(revoked) + await self._invoke_rebalance_listener_async( + 'on_partitions_revoked', revoked) except Exception: log.exception("User provided subscription rebalance listener %s" " for group %s failed on_partitions_revoked", @@ -573,12 +619,12 @@ def commit_offsets_sync(self, offsets, timeout_ms=None): assert all(map(lambda v: isinstance(v, OffsetAndMetadata), offsets.values())) self._invoke_completed_offset_commit_callbacks() - if not offsets: - return with self._client._lock: return self._manager.run(self._commit_offsets_sync_async, offsets, timeout_ms) async def _commit_offsets_sync_async(self, offsets, timeout_ms=None): + if not offsets: + return timer = Timer(timeout_ms) while True: await self.ensure_coordinator_ready_async(timeout_ms=timer.timeout_ms) diff --git a/test/consumer/test_coordinator.py b/test/consumer/test_coordinator.py index 62f2fe390..1fbd7c15c 100644 --- a/test/consumer/test_coordinator.py +++ b/test/consumer/test_coordinator.py @@ -3,7 +3,11 @@ import pytest -from kafka.consumer.subscription_state import SubscriptionState, ConsumerRebalanceListener +from kafka.consumer.subscription_state import ( + AsyncConsumerRebalanceListener, + ConsumerRebalanceListener, + SubscriptionState, +) from kafka.coordinator.assignors.abstract import ( ConsumerProtocolSubscription, ConsumerProtocolAssignment, ) @@ -14,10 +18,12 @@ from kafka.coordinator.consumer import 
ConsumerCoordinator import kafka.errors as Errors from kafka.future import Future +from kafka.coordinator.base import UnjoinedGroupException from kafka.protocol.consumer import ( OffsetCommitRequest, OffsetCommitResponse, OffsetFetchRequest, OffsetFetchResponse, - JoinGroupResponse, + JoinGroupRequest, JoinGroupResponse, + SyncGroupRequest, SyncGroupResponse, ) from kafka.protocol.metadata import MetadataResponse from kafka.structs import OffsetAndMetadata, TopicPartition @@ -125,7 +131,9 @@ def test_join_complete(mocker, coordinator): assert assignor.on_assignment.call_count == 0 assignment = ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'') generation = 12 - coordinator._on_join_complete(generation, 'member-foo', 'roundrobin', assignment.encode()) + coordinator._manager.run( + coordinator._on_join_complete_async, + generation, 'member-foo', 'roundrobin', assignment.encode()) assert assignor.on_assignment.call_count == 1 assignor.on_assignment.assert_called_with(assignment, generation) @@ -138,45 +146,13 @@ def test_join_complete_with_sticky_assignor(mocker, coordinator): assert assignor.on_assignment.call_count == 0 generation = 3 assignment = ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'') - coordinator._on_join_complete(generation, 'member-foo', 'sticky', assignment.encode()) + coordinator._manager.run( + coordinator._on_join_complete_async, + generation, 'member-foo', 'sticky', assignment.encode()) assert assignor.on_assignment.call_count == 1 assignor.on_assignment.assert_called_with(assignment, generation) -def test_subscription_listener(mocker, coordinator): - listener = mocker.MagicMock(spec=ConsumerRebalanceListener) - coordinator._subscription.subscribe( - topics=['foobar'], - listener=listener) - - coordinator._on_join_prepare(0, 'member-foo') - assert listener.on_partitions_revoked.call_count == 1 - listener.on_partitions_revoked.assert_called_with(set([])) - - assignment = ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'') - 
coordinator._on_join_complete( - 0, 'member-foo', 'roundrobin', assignment.encode()) - assert listener.on_partitions_assigned.call_count == 1 - listener.on_partitions_assigned.assert_called_with({TopicPartition('foobar', 0), TopicPartition('foobar', 1)}) - - -def test_subscription_listener_failure(mocker, coordinator): - listener = mocker.MagicMock(spec=ConsumerRebalanceListener) - coordinator._subscription.subscribe( - topics=['foobar'], - listener=listener) - - # exception raised in listener should not be re-raised by coordinator - listener.on_partitions_revoked.side_effect = Exception('crash') - coordinator._on_join_prepare(0, 'member-foo') - assert listener.on_partitions_revoked.call_count == 1 - - assignment = ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'') - coordinator._on_join_complete( - 0, 'member-foo', 'roundrobin', assignment.encode()) - assert listener.on_partitions_assigned.call_count == 1 - - def test_perform_assignment(mocker, coordinator): coordinator._subscription.subscribe(topics=['foo1']) members = [ @@ -209,9 +185,159 @@ def test_perform_assignment(mocker, coordinator): assert ret == assignments -def test_on_join_prepare(coordinator): +def test_on_join_prepare_async_invokes_sync_listener(mocker, coordinator): + coordinator.config['enable_auto_commit'] = False + listener = mocker.MagicMock(spec=ConsumerRebalanceListener) + coordinator._subscription.subscribe(topics=['foobar'], listener=listener) + + coordinator._manager.run(coordinator._on_join_prepare_async, 0, 'member-foo') + + assert listener.on_partitions_revoked.call_count == 1 + listener.on_partitions_revoked.assert_called_with(set()) + + +def test_on_join_prepare_async_awaits_async_listener(coordinator): + """An AsyncConsumerRebalanceListener subclass is accepted and awaited.""" + coordinator.config['enable_auto_commit'] = False + calls = [] + + class MyListener(AsyncConsumerRebalanceListener): + async def on_partitions_revoked(self, revoked): + calls.append(('revoked', revoked)) 
+ async def on_partitions_assigned(self, assigned): + calls.append(('assigned', assigned)) + + coordinator._subscription.subscribe(topics=['foobar'], listener=MyListener()) + coordinator._manager.run(coordinator._on_join_prepare_async, 0, 'member-foo') + + assert calls == [('revoked', set())] + + +def test_subscribe_rejects_non_listener(coordinator): + """Anything that isn't a (Async)ConsumerRebalanceListener is rejected.""" + with pytest.raises(TypeError): + coordinator._subscription.subscribe( + topics=['foobar'], listener=lambda revoked: None) + + +def test_slow_rebalance_listener_logs_warning(mocker, coordinator): + """A listener call exceeding the threshold logs a named warning.""" + coordinator.config['enable_auto_commit'] = False + + class SlowListener(ConsumerRebalanceListener): + def on_partitions_revoked(self, revoked): + time.sleep(0.01) # well under the threshold; shouldn't warn + def on_partitions_assigned(self, assigned): + pass + + coordinator._subscription.subscribe(topics=['foobar'], listener=SlowListener()) + log_warning = mocker.patch('kafka.coordinator.consumer.log.warning') + + # Below threshold: no warning. + coordinator._manager.run(coordinator._on_join_prepare_async, 0, 'member-foo') + assert not any( + 'Rebalance listener' in str(call.args[0]) + for call in log_warning.call_args_list) + + # Above threshold: drop the threshold to a tiny value and re-run. 
+ mocker.patch.object(coordinator, '_REBALANCE_LISTENER_WARN_SECS', 0.001) + log_warning.reset_mock() + coordinator._manager.run(coordinator._on_join_prepare_async, 0, 'member-foo') + matching = [c for c in log_warning.call_args_list + if 'Rebalance listener' in str(c.args[0])] + assert len(matching) == 1 + # log.warning(fmt, listener_class, method, group, elapsed) + _fmt, listener_class, method, _group, _elapsed = matching[0].args + assert listener_class == 'SlowListener' + assert method == 'on_partitions_revoked' + + +def test_on_join_prepare_async_listener_exception_is_caught(mocker, coordinator): + coordinator.config['enable_auto_commit'] = False + listener = mocker.MagicMock(spec=ConsumerRebalanceListener) + listener.on_partitions_revoked.side_effect = RuntimeError('listener crash') + coordinator._subscription.subscribe(topics=['foobar'], listener=listener) + + # Should not raise; should still complete the post-listener cleanup. + coordinator._manager.run(coordinator._on_join_prepare_async, 0, 'member-foo') + assert coordinator._is_leader is False + + +def test_on_join_prepare_async_skips_auto_commit_when_disabled(mocker, coordinator): + coordinator.config['enable_auto_commit'] = False + spy = mocker.spy(coordinator, '_commit_offsets_sync_async') coordinator._subscription.subscribe(topics=['foobar']) - coordinator._on_join_prepare(0, 'member-foo') + + coordinator._manager.run(coordinator._on_join_prepare_async, 0, 'member-foo') + + assert spy.call_count == 0 + + +def test_on_join_prepare_async_runs_auto_commit_when_enabled(mocker, coordinator): + coordinator.config['enable_auto_commit'] = True + async def _noop(*args, **kwargs): + return None + spy = mocker.patch.object(coordinator, '_commit_offsets_sync_async', + side_effect=_noop) + coordinator._subscription.subscribe(topics=['foobar']) + + coordinator._manager.run(coordinator._on_join_prepare_async, 0, 'member-foo') + + assert spy.call_count == 1 + + +def 
test_on_join_complete_async_invokes_sync_listener(mocker, coordinator): + listener = mocker.MagicMock(spec=ConsumerRebalanceListener) + coordinator._subscription.subscribe(topics=['foobar'], listener=listener) + assignor = RoundRobinPartitionAssignor() + coordinator._assignors = {assignor.name: assignor} + assignment = ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'') + + coordinator._manager.run( + coordinator._on_join_complete_async, + 12, 'member-foo', 'roundrobin', assignment.encode()) + + assert listener.on_partitions_assigned.call_count == 1 + listener.on_partitions_assigned.assert_called_with( + {TopicPartition('foobar', 0), TopicPartition('foobar', 1)}) + + +def test_on_join_complete_async_awaits_async_listener(coordinator): + """An AsyncConsumerRebalanceListener subclass is accepted and awaited.""" + calls = [] + + class MyListener(AsyncConsumerRebalanceListener): + async def on_partitions_revoked(self, revoked): + calls.append(('revoked', revoked)) + async def on_partitions_assigned(self, assigned): + calls.append(('assigned', assigned)) + + coordinator._subscription.subscribe(topics=['foobar'], listener=MyListener()) + assignor = RoundRobinPartitionAssignor() + coordinator._assignors = {assignor.name: assignor} + assignment = ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'') + + coordinator._manager.run( + coordinator._on_join_complete_async, + 12, 'member-foo', 'roundrobin', assignment.encode()) + + assert calls == [( + 'assigned', + {TopicPartition('foobar', 0), TopicPartition('foobar', 1)})] + + +def test_on_join_complete_async_listener_exception_is_caught(mocker, coordinator): + listener = mocker.MagicMock(spec=ConsumerRebalanceListener) + listener.on_partitions_assigned.side_effect = RuntimeError('listener crash') + coordinator._subscription.subscribe(topics=['foobar'], listener=listener) + assignor = RoundRobinPartitionAssignor() + coordinator._assignors = {assignor.name: assignor} + assignment = ConsumerProtocolAssignment(0, 
[('foobar', [0, 1])], b'') + + # Should not raise. + coordinator._manager.run( + coordinator._on_join_complete_async, + 12, 'member-foo', 'roundrobin', assignment.encode()) def test_need_rejoin(coordinator): @@ -669,6 +795,418 @@ def test_handle_offset_fetch_response(coordinator, offsets, response, error, dea assert coordinator.coordinator_id is (None if dead else 0) +def _join_response(error_code=0, generation_id=1, member_id='member-1', + leader='member-1', protocol_name='range', members=None): + return JoinGroupResponse( + throttle_time_ms=0, + error_code=error_code, + generation_id=generation_id, + protocol_name=protocol_name, + leader=leader, + member_id=member_id, + members=members or []) + + +@pytest.mark.parametrize('error_code,error_type,coordinator_dead,resets_member_id,resets_generation', [ + (0, None, False, False, False), + (Errors.CoordinatorLoadInProgressError.errno, Errors.CoordinatorLoadInProgressError, False, False, False), + (Errors.UnknownMemberIdError.errno, Errors.UnknownMemberIdError, False, True, True), + (Errors.CoordinatorNotAvailableError.errno, Errors.CoordinatorNotAvailableError, True, False, False), + (Errors.NotCoordinatorError.errno, Errors.NotCoordinatorError, True, False, False), + (Errors.InconsistentGroupProtocolError.errno, Errors.InconsistentGroupProtocolError, False, False, False), + (Errors.InvalidSessionTimeoutError.errno, Errors.InvalidSessionTimeoutError, False, False, False), + (Errors.InvalidGroupIdError.errno, Errors.InvalidGroupIdError, False, False, False), + (Errors.GroupAuthorizationFailedError.errno, Errors.GroupAuthorizationFailedError, False, False, False), + (Errors.GroupMaxSizeReachedError.errno, Errors.GroupMaxSizeReachedError, False, False, False), + (Errors.FencedInstanceIdError.errno, Errors.FencedInstanceIdError, False, False, False), + (Errors.MemberIdRequiredError.errno, Errors.MemberIdRequiredError, False, True, True), + (Errors.RebalanceInProgressError.errno, Errors.RebalanceInProgressError, False, 
False, False), + # Unmapped error code: should raise the corresponding generic error. + (Errors.UnknownError.errno, Errors.UnknownError, False, False, False), +]) +def test_process_join_group_response(request, coordinator, error_code, error_type, + coordinator_dead, resets_member_id, + resets_generation): + # Avoid LeaveGroup attempt during teardown (close() only sends LeaveGroup + # when state is not UNJOINED). + request.addfinalizer(lambda: setattr(coordinator, 'state', MemberState.UNJOINED)) + coordinator.coordinator_id = 0 + coordinator.state = MemberState.REBALANCING + coordinator._generation = Generation(7, 'old-member', 'range') + + response = _join_response( + error_code=error_code, + generation_id=42, + member_id='broker-assigned' if error_code == Errors.MemberIdRequiredError.errno else 'member-1') + + if error_type is None: + ret = coordinator._process_join_group_response(response, send_time=time.monotonic()) + assert ret is response + # State mutation: generation updated from response. + assert coordinator._generation.generation_id == 42 + assert coordinator._generation.member_id == 'member-1' + assert coordinator._generation.protocol == 'range' + assert coordinator.coordinator_id == 0 + else: + with pytest.raises(error_type): + coordinator._process_join_group_response(response, send_time=time.monotonic()) + if coordinator_dead: + assert coordinator.coordinator_id is None + else: + assert coordinator.coordinator_id == 0 + if resets_member_id: + # MemberIdRequired captures the broker-assigned id; UnknownMemberId + # clears it back to UNKNOWN_MEMBER_ID. 
+ if error_code == Errors.MemberIdRequiredError.errno: + assert coordinator._generation.member_id == 'broker-assigned' + else: + assert coordinator._generation.member_id == '' + if resets_generation: + assert coordinator._generation.generation_id == -1 + + +def test_process_join_group_response_state_not_rebalancing(coordinator): + """Defensive: if state changed underneath us, raise UnjoinedGroupException.""" + coordinator.coordinator_id = 0 + coordinator.state = MemberState.UNJOINED + response = _join_response(error_code=0) + with pytest.raises(UnjoinedGroupException): + coordinator._process_join_group_response(response, send_time=time.monotonic()) + + +@pytest.mark.parametrize('error_code,error_type,coordinator_dead,resets_generation,requests_rejoin', [ + (0, None, False, False, False), + (Errors.GroupAuthorizationFailedError.errno, Errors.GroupAuthorizationFailedError, False, False, True), + (Errors.RebalanceInProgressError.errno, Errors.RebalanceInProgressError, False, False, True), + (Errors.FencedInstanceIdError.errno, Errors.FencedInstanceIdError, False, False, True), + (Errors.UnknownMemberIdError.errno, Errors.UnknownMemberIdError, False, True, True), + (Errors.IllegalGenerationError.errno, Errors.IllegalGenerationError, False, True, True), + (Errors.CoordinatorNotAvailableError.errno, Errors.CoordinatorNotAvailableError, True, False, True), + (Errors.NotCoordinatorError.errno, Errors.NotCoordinatorError, True, False, True), + (Errors.UnknownError.errno, Errors.UnknownError, False, False, True), +]) +def test_process_sync_group_response(request, coordinator, error_code, error_type, + coordinator_dead, resets_generation, + requests_rejoin): + request.addfinalizer(lambda: setattr(coordinator, 'state', MemberState.UNJOINED)) + coordinator.coordinator_id = 0 + coordinator._generation = Generation(7, 'member-1', 'range') + coordinator.rejoin_needed = False + + assignment_bytes = b'\x00\x01\x02' + response = SyncGroupResponse( + throttle_time_ms=0, + 
error_code=error_code, + assignment=assignment_bytes) + + if error_type is None: + ret = coordinator._process_sync_group_response(response, send_time=time.monotonic()) + assert ret == assignment_bytes + assert coordinator.rejoin_needed is False + assert coordinator.coordinator_id == 0 + else: + with pytest.raises(error_type): + coordinator._process_sync_group_response(response, send_time=time.monotonic()) + if requests_rejoin: + assert coordinator.rejoin_needed is True + if coordinator_dead: + assert coordinator.coordinator_id is None + if resets_generation: + assert coordinator._generation.generation_id == -1 + assert coordinator._generation.member_id == '' + + +def _join_response_object(error_code=0, generation_id=42, + member_id='member-1', leader='member-1', + protocol_name='range', members=None): + return JoinGroupResponse( + throttle_time_ms=0, + error_code=error_code, + generation_id=generation_id, + protocol_type='consumer', + protocol_name=protocol_name, + leader=leader, + member_id=member_id, + members=members or []) + + +def _sync_response_object(error_code=0, assignment=b''): + return SyncGroupResponse( + throttle_time_ms=0, + error_code=error_code, + protocol_type='consumer', + protocol_name='range', + assignment=assignment) + + +def test_do_join_and_sync_async_follower(request, broker, seeded_coord): + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + # Default broker.broker_version=(4,2) → JoinGroup v9, SyncGroup v5. + # Follower: leader != our member_id. 
+ broker.respond(JoinGroupRequest, _join_response_object( + leader='leader-x', member_id='member-1', members=[])) + expected_assignment = ConsumerProtocolAssignment( + 0, [('foobar', [0, 1])], b'').encode() + broker.respond(SyncGroupRequest, _sync_response_object( + assignment=expected_assignment)) + + result = seeded_coord._manager.run(seeded_coord._do_join_and_sync_async) + + assert result == expected_assignment + assert seeded_coord._generation.generation_id == 42 + assert seeded_coord._generation.member_id == 'member-1' + assert seeded_coord._generation.protocol == 'range' + assert seeded_coord.state == MemberState.REBALANCING + + +def test_do_join_and_sync_async_leader(request, mocker, broker, seeded_coord): + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + # Leader: response.leader == response.member_id. Members include the leader. + member_metadata = ConsumerProtocolSubscription(0, ['foobar'], b'').encode() + member = JoinGroupResponse.JoinGroupResponseMember( + member_id='member-1', + group_instance_id=None, + metadata=member_metadata) + broker.respond(JoinGroupRequest, _join_response_object( + leader='member-1', member_id='member-1', members=[member])) + + # Capture the SyncGroup request to verify the leader sent assignments. + captured = {} + + def sync_handler(api_key, api_version, correlation_id, request_bytes): + captured['request'] = SyncGroupRequest.decode( + request_bytes, version=api_version, header=True) + return _sync_response_object( + assignment=ConsumerProtocolAssignment( + 0, [('foobar', [0, 1])], b'').encode()) + + broker.respond_fn(SyncGroupRequest, sync_handler) + + # Spy on _perform_assignment to confirm the leader path ran the assignor. 
+ spy = mocker.spy(seeded_coord, '_perform_assignment') + + result = seeded_coord._manager.run(seeded_coord._do_join_and_sync_async) + + assert spy.call_count == 1 + leader_id, protocol_name, members_arg = spy.call_args[0] + assert leader_id == 'member-1' + assert protocol_name == 'range' + assert len(members_arg) == 1 + assert members_arg[0].member_id == 'member-1' + # SyncGroup carried a non-empty assignment list. + assert len(captured['request'].assignments) >= 1 + # Returned the assignment bytes the broker sent back. + assert result == ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'').encode() + + +def test_do_join_and_sync_async_coordinator_unknown(request, seeded_coord): + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + seeded_coord.coordinator_id = None # force coordinator_unknown + with pytest.raises(Errors.CoordinatorNotAvailableError): + seeded_coord._manager.run(seeded_coord._do_join_and_sync_async) + + +@pytest.mark.parametrize('error_code,error_type', [ + (Errors.CoordinatorLoadInProgressError.errno, Errors.CoordinatorLoadInProgressError), + (Errors.UnknownMemberIdError.errno, Errors.UnknownMemberIdError), + (Errors.NotCoordinatorError.errno, Errors.NotCoordinatorError), + (Errors.MemberIdRequiredError.errno, Errors.MemberIdRequiredError), + (Errors.RebalanceInProgressError.errno, Errors.RebalanceInProgressError), + (Errors.GroupAuthorizationFailedError.errno, Errors.GroupAuthorizationFailedError), +]) +def test_do_join_and_sync_async_join_error(request, broker, seeded_coord, + error_code, error_type): + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + broker.respond(JoinGroupRequest, _join_response_object(error_code=error_code)) + with pytest.raises(error_type): + seeded_coord._manager.run(seeded_coord._do_join_and_sync_async) + + +@pytest.mark.parametrize('error_code,error_type', [ + (Errors.GroupAuthorizationFailedError.errno, Errors.GroupAuthorizationFailedError), + 
(Errors.RebalanceInProgressError.errno, Errors.RebalanceInProgressError), + (Errors.UnknownMemberIdError.errno, Errors.UnknownMemberIdError), + (Errors.IllegalGenerationError.errno, Errors.IllegalGenerationError), + (Errors.NotCoordinatorError.errno, Errors.NotCoordinatorError), +]) +def test_do_join_and_sync_async_sync_error(request, broker, seeded_coord, + error_code, error_type): + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + # JoinGroup succeeds (follower) so we get to SyncGroup. + broker.respond(JoinGroupRequest, _join_response_object( + leader='leader-x', member_id='member-1')) + broker.respond(SyncGroupRequest, _sync_response_object(error_code=error_code)) + with pytest.raises(error_type): + seeded_coord._manager.run(seeded_coord._do_join_and_sync_async) + # All sync errors flip rejoin_needed via request_rejoin(). + assert seeded_coord.rejoin_needed is True + + +def test_join_group_async_no_rejoin_returns_true(request, mocker, broker, seeded_coord): + """need_rejoin() False -> short-circuits to True without any requests.""" + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + mocker.patch.object(seeded_coord, 'need_rejoin', return_value=False) + seeded_coord.state = MemberState.STABLE + + before = broker.requests_received + result = seeded_coord._manager.run(seeded_coord.join_group_async, 5000) + + assert result is True + assert broker.requests_received == before + + +def test_join_group_async_happy_path_follower(request, broker, seeded_coord): + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + seeded_coord.rejoin_needed = True + seeded_coord.state = MemberState.UNJOINED + broker.respond(JoinGroupRequest, _join_response_object( + leader='leader-x', member_id='member-1', members=[])) + broker.respond(SyncGroupRequest, _sync_response_object( + assignment=ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'').encode())) + + result = 
seeded_coord._manager.run(seeded_coord.join_group_async, 5000) + + assert result is True + assert seeded_coord.state == MemberState.STABLE + assert seeded_coord.rejoin_needed is False + assert seeded_coord.rejoining is False + assert seeded_coord._heartbeat_enabled is True + + +def test_join_group_async_retries_on_retriable_error(request, broker, seeded_coord): + """First JoinGroup fails with RebalanceInProgress; loop retries and succeeds.""" + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + seeded_coord.rejoin_needed = True + seeded_coord.state = MemberState.UNJOINED + broker.respond(JoinGroupRequest, _join_response_object( + error_code=Errors.RebalanceInProgressError.errno)) + broker.respond(JoinGroupRequest, _join_response_object( + leader='leader-x', member_id='member-1', members=[])) + broker.respond(SyncGroupRequest, _sync_response_object( + assignment=ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'').encode())) + + result = seeded_coord._manager.run(seeded_coord.join_group_async, 5000) + + assert result is True + assert seeded_coord.state == MemberState.STABLE + + +def test_join_group_async_raises_non_retriable(request, broker, seeded_coord): + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + seeded_coord.rejoin_needed = True + seeded_coord.state = MemberState.UNJOINED + broker.respond(JoinGroupRequest, _join_response_object( + error_code=Errors.GroupAuthorizationFailedError.errno)) + + with pytest.raises(Errors.GroupAuthorizationFailedError): + seeded_coord._manager.run(seeded_coord.join_group_async, 5000) + + +def test_join_group_async_returns_false_on_short_timeout_and_caches_task( + request, broker, seeded_coord): + """Short consumer.poll(timeout_ms=N) should return False instead of + hanging when the broker is slow to respond to JoinGroup; the in-flight + task is cached so the next poll re-awaits it instead of sending a fresh + JoinGroup. 
+ + Regression for the test_group integration hang where 4 consumers tearing + down concurrently left one stuck awaiting JoinGroup while the broker + waited for the others to rejoin. Both properties are necessary: + - timer must fire (else the user thread hangs and never sees stop) + - in-flight task must be cached (else next poll sends a duplicate + JoinGroup, confusing the broker's rebalance state) + """ + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + seeded_coord.rejoin_needed = True + seeded_coord.state = MemberState.UNJOINED + + # JoinGroup response future controlled by the test. Hangs until released. + join_response_pending = Future() + join_request_count = [0] + + async def slow_join_handler(api_key, api_version, correlation_id, request_bytes): + join_request_count[0] += 1 + # Block until the test releases the future, simulating a broker + # that's holding JoinGroup waiting for other members to rejoin. + await join_response_pending + return _join_response_object( + leader='leader-x', member_id='member-1', members=[]) + + broker.respond_fn(JoinGroupRequest, slow_join_handler) + + # First call: 50ms timer must expire and return False quickly. If the + # await on JoinGroup is not timer-aware, this call hangs until the + # connection's request_timeout_ms fires (~5s in the fixture). + start = time.monotonic() + result = seeded_coord._manager.run(seeded_coord.join_group_async, 50) + elapsed = time.monotonic() - start + assert result is False + assert elapsed < 1.0, ( + 'join_group_async did not respect timer.timeout_ms; took %.2fs' + % elapsed) + assert join_request_count[0] == 1 + + # Second call: broker is still hanging. Should reuse the cached + # in-flight task instead of sending a duplicate JoinGroup. 
+ start = time.monotonic() + result = seeded_coord._manager.run(seeded_coord.join_group_async, 50) + elapsed = time.monotonic() - start + assert result is False + assert elapsed < 1.0 + assert join_request_count[0] == 1, ( + 'duplicate JoinGroup sent; cached task was not reused') + + # Release the broker; next call should complete using the cached task. + broker.respond(SyncGroupRequest, _sync_response_object( + assignment=ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'').encode())) + join_response_pending.success(None) + + result = seeded_coord._manager.run(seeded_coord.join_group_async, 5000) + assert result is True + assert join_request_count[0] == 1, ( + 'duplicate JoinGroup sent on the success path') + assert seeded_coord.state == MemberState.STABLE + + +@pytest.mark.parametrize("broker", [(0, 8, 0)], indirect=True) +def test_join_group_async_unsupported_version(broker, coordinator): + with pytest.raises(Errors.UnsupportedVersionError): + coordinator._manager.run(coordinator.join_group_async, None) + + +def test_ensure_active_group_async_happy_path(request, broker, seeded_coord): + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + seeded_coord.rejoin_needed = True + seeded_coord.state = MemberState.UNJOINED + broker.respond(JoinGroupRequest, _join_response_object( + leader='leader-x', member_id='member-1', members=[])) + broker.respond(SyncGroupRequest, _sync_response_object( + assignment=ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'').encode())) + + result = seeded_coord._manager.run(seeded_coord.ensure_active_group_async, 5000) + + assert result is True + assert seeded_coord.state == MemberState.STABLE + # Heartbeat loop coroutine was scheduled. 
+ assert seeded_coord._heartbeat_loop_future is not None + + +def test_ensure_active_group_sync_facade(request, broker, seeded_coord): + """The sync ensure_active_group facade dispatches via manager.run.""" + request.addfinalizer(lambda: setattr(seeded_coord, 'state', MemberState.UNJOINED)) + seeded_coord.rejoin_needed = True + seeded_coord.state = MemberState.UNJOINED + broker.respond(JoinGroupRequest, _join_response_object( + leader='leader-x', member_id='member-1', members=[])) + broker.respond(SyncGroupRequest, _sync_response_object( + assignment=ConsumerProtocolAssignment(0, [('foobar', [0, 1])], b'').encode())) + + result = seeded_coord.ensure_active_group(timeout_ms=5000) + + assert result is True + assert seeded_coord.state == MemberState.STABLE + + def test_heartbeat(mocker, coordinator): coordinator.coordinator_id = 0 coordinator.state = MemberState.STABLE @@ -732,16 +1270,3 @@ def test_lookup_coordinator_failure(mocker, coordinator): return_value=Future().failure(Exception('foobar'))) future = coordinator.lookup_coordinator() assert future.failed() - - -def test_ensure_active_group(mocker, coordinator): - coordinator._subscription.subscribe(topics=['foobar']) - mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) - mocker.patch.object(coordinator, '_send_join_group_request', return_value=Future().success(True)) - mocker.patch.object(coordinator, 'need_rejoin', side_effect=[True, False]) - mocker.patch.object(coordinator, '_on_join_complete') - mocker.patch.object(coordinator, '_heartbeat_thread') - - coordinator.ensure_active_group() - - coordinator._send_join_group_request.assert_called_once_with()