diff --git a/kafka/protocol/admin/topics.py b/kafka/protocol/admin/topics.py index cc2962c13..83afe6e8e 100644 --- a/kafka/protocol/admin/topics.py +++ b/kafka/protocol/admin/topics.py @@ -29,6 +29,9 @@ class AlterPartitionReassignmentsResponse(ApiMessage): pass class ListPartitionReassignmentsRequest(ApiMessage): pass class ListPartitionReassignmentsResponse(ApiMessage): pass +class DescribeTopicPartitionsRequest(ApiMessage): pass +class DescribeTopicPartitionsResponse(ApiMessage): pass + class DeleteRecordsRequest(ApiMessage): pass class DeleteRecordsResponse(ApiMessage): pass @@ -48,6 +51,7 @@ class ElectionType(IntEnum): 'AlterPartitionRequest', 'AlterPartitionResponse', 'AlterPartitionReassignmentsRequest', 'AlterPartitionReassignmentsResponse', 'ListPartitionReassignmentsRequest', 'ListPartitionReassignmentsResponse', + 'DescribeTopicPartitionsRequest', 'DescribeTopicPartitionsResponse', 'DeleteRecordsRequest', 'DeleteRecordsResponse', 'ElectLeadersRequest', 'ElectLeadersResponse', 'ElectionType', ] diff --git a/kafka/protocol/admin/topics.pyi b/kafka/protocol/admin/topics.pyi index 5d922d830..ceda4b494 100644 --- a/kafka/protocol/admin/topics.pyi +++ b/kafka/protocol/admin/topics.pyi @@ -6,7 +6,7 @@ from enum import IntEnum from kafka.protocol.api_message import ApiMessage from kafka.protocol.data_container import DataContainer -__all__ = ['CreateTopicsRequest', 'CreateTopicsResponse', 'DeleteTopicsRequest', 'DeleteTopicsResponse', 'CreatePartitionsRequest', 'CreatePartitionsResponse', 'AlterPartitionRequest', 'AlterPartitionResponse', 'AlterPartitionReassignmentsRequest', 'AlterPartitionReassignmentsResponse', 'ListPartitionReassignmentsRequest', 'ListPartitionReassignmentsResponse', 'DeleteRecordsRequest', 'DeleteRecordsResponse', 'ElectLeadersRequest', 'ElectLeadersResponse', 'ElectionType'] +__all__ = ['CreateTopicsRequest', 'CreateTopicsResponse', 'DeleteTopicsRequest', 'DeleteTopicsResponse', 'CreatePartitionsRequest', 'CreatePartitionsResponse', 
'AlterPartitionRequest', 'AlterPartitionResponse', 'AlterPartitionReassignmentsRequest', 'AlterPartitionReassignmentsResponse', 'ListPartitionReassignmentsRequest', 'ListPartitionReassignmentsResponse', 'DescribeTopicPartitionsRequest', 'DescribeTopicPartitionsResponse', 'DeleteRecordsRequest', 'DeleteRecordsResponse', 'ElectLeadersRequest', 'ElectLeadersResponse', 'ElectionType'] class CreateTopicsRequest(ApiMessage): class CreatableTopic(DataContainer): @@ -746,6 +746,161 @@ class ListPartitionReassignmentsResponse(ApiMessage): def expect_response(self) -> bool: ... def with_header(self, correlation_id: int = 0, client_id: str = "kafka-python") -> None: ... +class DescribeTopicPartitionsRequest(ApiMessage): + class TopicRequest(DataContainer): + name: str + def __init__( + self, + *args: Any, + name: str = ..., + version: int | None = None, + **kwargs: Any, + ) -> None: ... + @property + def version(self) -> int | None: ... + def to_dict(self, meta: bool = False, json: bool = True) -> dict: ... + + class Cursor(DataContainer): + topic_name: str + partition_index: int + def __init__( + self, + *args: Any, + topic_name: str = ..., + partition_index: int = ..., + version: int | None = None, + **kwargs: Any, + ) -> None: ... + @property + def version(self) -> int | None: ... + def to_dict(self, meta: bool = False, json: bool = True) -> dict: ... + + topics: list[TopicRequest] + response_partition_limit: int + cursor: Cursor | None + def __init__( + self, + *args: Any, + topics: list[TopicRequest] = ..., + response_partition_limit: int = ..., + cursor: Cursor | None = ..., + version: int | None = None, + **kwargs: Any, + ) -> None: ... + @property + def version(self) -> int | None: ... + def to_dict(self, meta: bool = False, json: bool = True) -> dict: ... + name: str + type: str + API_KEY: int + API_VERSION: int + valid_versions: tuple[int, int] + min_version: int + max_version: int + @property + def header(self) -> Any: ... 
+ @classmethod + def is_request(cls) -> bool: ... + def expect_response(self) -> bool: ... + def with_header(self, correlation_id: int = 0, client_id: str = "kafka-python") -> None: ... + +class DescribeTopicPartitionsResponse(ApiMessage): + class DescribeTopicPartitionsResponseTopic(DataContainer): + class DescribeTopicPartitionsResponsePartition(DataContainer): + error_code: int + partition_index: int + leader_id: int + leader_epoch: int + replica_nodes: list[int] + isr_nodes: list[int] + eligible_leader_replicas: list[int] | None + last_known_elr: list[int] | None + offline_replicas: list[int] + def __init__( + self, + *args: Any, + error_code: int = ..., + partition_index: int = ..., + leader_id: int = ..., + leader_epoch: int = ..., + replica_nodes: list[int] = ..., + isr_nodes: list[int] = ..., + eligible_leader_replicas: list[int] | None = ..., + last_known_elr: list[int] | None = ..., + offline_replicas: list[int] = ..., + version: int | None = None, + **kwargs: Any, + ) -> None: ... + @property + def version(self) -> int | None: ... + def to_dict(self, meta: bool = False, json: bool = True) -> dict: ... + + error_code: int + name: str | None + topic_id: uuid.UUID + is_internal: bool + partitions: list[DescribeTopicPartitionsResponsePartition] + topic_authorized_operations: int + def __init__( + self, + *args: Any, + error_code: int = ..., + name: str | None = ..., + topic_id: uuid.UUID = ..., + is_internal: bool = ..., + partitions: list[DescribeTopicPartitionsResponsePartition] = ..., + topic_authorized_operations: int = ..., + version: int | None = None, + **kwargs: Any, + ) -> None: ... + @property + def version(self) -> int | None: ... + def to_dict(self, meta: bool = False, json: bool = True) -> dict: ... + + class Cursor(DataContainer): + topic_name: str + partition_index: int + def __init__( + self, + *args: Any, + topic_name: str = ..., + partition_index: int = ..., + version: int | None = None, + **kwargs: Any, + ) -> None: ... 
+ @property + def version(self) -> int | None: ... + def to_dict(self, meta: bool = False, json: bool = True) -> dict: ... + + throttle_time_ms: int + topics: list[DescribeTopicPartitionsResponseTopic] + next_cursor: Cursor | None + def __init__( + self, + *args: Any, + throttle_time_ms: int = ..., + topics: list[DescribeTopicPartitionsResponseTopic] = ..., + next_cursor: Cursor | None = ..., + version: int | None = None, + **kwargs: Any, + ) -> None: ... + @property + def version(self) -> int | None: ... + def to_dict(self, meta: bool = False, json: bool = True) -> dict: ... + name: str + type: str + API_KEY: int + API_VERSION: int + valid_versions: tuple[int, int] + min_version: int + max_version: int + @property + def header(self) -> Any: ... + @classmethod + def is_request(cls) -> bool: ... + def expect_response(self) -> bool: ... + def with_header(self, correlation_id: int = 0, client_id: str = "kafka-python") -> None: ... + class DeleteRecordsRequest(ApiMessage): class DeleteRecordsTopic(DataContainer): class DeleteRecordsPartition(DataContainer): diff --git a/kafka/protocol/data_container.py b/kafka/protocol/data_container.py index 5e0060d2b..6d4200f21 100644 --- a/kafka/protocol/data_container.py +++ b/kafka/protocol/data_container.py @@ -137,9 +137,11 @@ def _to_dict_vals(self, meta=False, json=True): if self._version is not None and not field.for_version_q(self._version): continue if field.is_struct(): - yield (field.name, dict(getattr(self, field.name)._to_dict_vals(meta=meta, json=json))) + val = getattr(self, field.name) + yield (field.name, None if val is None else dict(val._to_dict_vals(meta=meta, json=json))) elif field.is_struct_array(): - yield (field.name, [dict(val._to_dict_vals(meta=meta, json=json)) for val in getattr(self, field.name)]) + val = getattr(self, field.name) + yield (field.name, None if val is None else [dict(v._to_dict_vals(meta=meta, json=json)) for v in val]) else: val = getattr(self, field.name) yield (field.name, 
field.to_json(val) if json else val) diff --git a/kafka/protocol/schemas/fields/base.py b/kafka/protocol/schemas/fields/base.py index 1f11430eb..c683208ab 100644 --- a/kafka/protocol/schemas/fields/base.py +++ b/kafka/protocol/schemas/fields/base.py @@ -113,6 +113,11 @@ def max_version(self): def for_version_q(self, version): return self._versions[0] <= version <= self._versions[1] + def nullable_for_version_q(self, version): + if self._nullable_versions is None: + return False + return self._nullable_versions[0] <= version <= self._nullable_versions[1] + def tagged_field_q(self, version): if self._tag is None or self._tagged_versions is None: return False diff --git a/kafka/protocol/schemas/fields/struct.py b/kafka/protocol/schemas/fields/struct.py index d441de225..2147a361e 100644 --- a/kafka/protocol/schemas/fields/struct.py +++ b/kafka/protocol/schemas/fields/struct.py @@ -72,6 +72,15 @@ def untagged_fields(self, version): return self._untagged_fields_cache[version] def encode(self, item, version=None, compact=False, tagged=False): + # Nested nullable struct: 1-byte null-prefix; top-level message structs never + # have nullableVersions set, so no guard is needed here. NOTE(review): Kafka's + # own generator writes -1 (0xff) for null and reads byte < 0 as null -- verify + # that the 0x00/0x01 convention used below interoperates with real brokers. 
+ if self.nullable_for_version_q(version): + if item is None: + return b'\x00' + prefix = b'\x01' + else: + prefix = b'' fields = self.untagged_fields(version) if isinstance(item, tuple): getter = lambda item, i, field: item[i] @@ -94,10 +103,26 @@ def encode(self, item, version=None, compact=False, tagged=False): encoded.append(self.tagged_fields(version).encode(tags, version=version)) elif tagged is None: encoded.append(TaggedFields.encode_empty()) - return b''.join(encoded) + return prefix + b''.join(encoded) def emit_encode_into(self, ctx, item_expr, indent, version=None, compact=False, tagged=False, tuple_access=False): + # Top-level struct (item_expr == 'item') has its nullability handled + # by the parent struct; only inline null-prefix when this is a nested + # nullable struct field. + inline_nullable = ( + self.nullable_for_version_q(version) + and item_expr != 'item' + and not tuple_access + ) + if inline_nullable: + ctx.emit(indent, 'if %s is None:' % item_expr) + ctx.emit(indent, ' buf[pos] = 0') + ctx.emit(indent, ' pos += 1') + ctx.emit(indent, 'else:') + ctx.emit(indent, ' buf[pos] = 1') + ctx.emit(indent, ' pos += 1') + indent = indent + ' ' fields = self.untagged_fields(version) for i, field in enumerate(fields): if tuple_access: @@ -117,6 +142,14 @@ def emit_encode_into(self, ctx, item_expr, indent, version=None, compact=False, ctx.emit(indent, 'pos += 1') def encode_into(self, item, out, version=None, compact=False, tagged=False): + if self.nullable_for_version_q(version): + out.ensure(1) + if item is None: + out.buf[out.pos] = 0 + out.pos += 1 + return + out.buf[out.pos] = 1 + out.pos += 1 fields = self.untagged_fields(version) if isinstance(item, tuple): for i, field in enumerate(fields): @@ -162,6 +195,19 @@ def emit_decode_from(self, ctx, var_name, indent, version=None, compact=False, t Batches adjacent batchable fields into single unpack_from calls. 
""" + # Top-level struct decode (var_name == 'obj') has no outer null-prefix; + # only a nested nullable struct field consumes one. + inline_nullable = ( + self.nullable_for_version_q(version) + and var_name != 'obj' + ) + if inline_nullable: + ctx.emit(indent, 'if data[pos] == 0:') + ctx.emit(indent, ' pos += 1') + ctx.emit(indent, ' %s = None' % var_name) + ctx.emit(indent, 'else:') + ctx.emit(indent, ' pos += 1') + indent = indent + ' ' fields = self.untagged_fields(version) data_class = self.data_class @@ -279,6 +325,9 @@ def compiled_decode_from(self, version, compact=False, tagged=False, data_class= return self._compiled_encoders[key] def decode(self, data, version=None, compact=False, tagged=False, data_class=None): + if self.nullable_for_version_q(version): + if data.read(1) == b'\x00': + return None if data_class is None: data_class = self.data_class decoded = { diff --git a/kafka/protocol/schemas/fields/struct_array.py b/kafka/protocol/schemas/fields/struct_array.py index 39d63280f..376041b16 100644 --- a/kafka/protocol/schemas/fields/struct_array.py +++ b/kafka/protocol/schemas/fields/struct_array.py @@ -26,6 +26,10 @@ def __init__(self, json, array_of=None): array_of = self.parse_inner_type(json) assert array_of is not None, 'json does not contain a StructArray!' super().__init__(json, array_of=array_of) + # nullableVersions on the JSON describes the array's nullability, not + # the inner struct's. Clear it on the inner struct so StructField does + # not try to emit a per-element null-prefix when encoding/decoding. 
+ array_of._nullable_versions = None # map_key will be (idx, field) of the mapKey field if found self.map_key = next(filter(lambda x: x[1]._json.get('mapKey'), enumerate(self._fields)), None) diff --git a/kafka/protocol/schemas/resources/DescribeTopicPartitionsRequest.json b/kafka/protocol/schemas/resources/DescribeTopicPartitionsRequest.json new file mode 100644 index 000000000..fa79989ff --- /dev/null +++ b/kafka/protocol/schemas/resources/DescribeTopicPartitionsRequest.json @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 75, + "type": "request", + "listeners": ["broker"], + "name": "DescribeTopicPartitionsRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Topics", "type": "[]TopicRequest", "versions": "0+", + "about": "The topics to fetch details for.", + "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." } + ] + }, + { "name": "ResponsePartitionLimit", "type": "int32", "versions": "0+", "default": "2000", + "about": "The maximum number of partitions included in the response." 
}, + { "name": "Cursor", "type": "Cursor", "versions": "0+", "nullableVersions": "0+", "default": "null", + "about": "The first topic and partition index to fetch details for.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The name for the first topic to process." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", "about": "The partition index to start with." } + ]} + ] +} diff --git a/kafka/protocol/schemas/resources/DescribeTopicPartitionsResponse.json b/kafka/protocol/schemas/resources/DescribeTopicPartitionsResponse.json new file mode 100644 index 000000000..668c85431 --- /dev/null +++ b/kafka/protocol/schemas/resources/DescribeTopicPartitionsResponse.json @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 75, + "type": "response", + "name": "DescribeTopicPartitionsResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." 
}, + { "name": "Topics", "type": "[]DescribeTopicPartitionsResponseTopic", "versions": "0+", + "about": "Each topic in the response.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The topic error, or 0 if there was no error." }, + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", "nullableVersions": "0+", + "about": "The topic name." }, + { "name": "TopicId", "type": "uuid", "versions": "0+", "ignorable": true, "about": "The topic id." }, + { "name": "IsInternal", "type": "bool", "versions": "0+", "default": "false", "ignorable": true, + "about": "True if the topic is internal." }, + { "name": "Partitions", "type": "[]DescribeTopicPartitionsResponsePartition", "versions": "0+", + "about": "Each partition in the topic.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The partition error, or 0 if there was no error." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the leader broker." }, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", "default": "-1", "ignorable": true, + "about": "The leader epoch of this partition." }, + { "name": "ReplicaNodes", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The set of all nodes that host this partition." }, + { "name": "IsrNodes", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The set of nodes that are in sync with the leader for this partition." }, + { "name": "EligibleLeaderReplicas", "type": "[]int32", "default": "null", "entityType": "brokerId", + "versions": "0+", "nullableVersions": "0+", + "about": "The new eligible leader replicas otherwise." 
}, + { "name": "LastKnownElr", "type": "[]int32", "default": "null", "entityType": "brokerId", + "versions": "0+", "nullableVersions": "0+", + "about": "The last known ELR." }, + { "name": "OfflineReplicas", "type": "[]int32", "versions": "0+", "ignorable": true, "entityType": "brokerId", + "about": "The set of offline replicas of this partition." }]}, + { "name": "TopicAuthorizedOperations", "type": "int32", "versions": "0+", "default": "-2147483648", + "about": "32-bit bitfield to represent authorized operations for this topic." }] + }, + { "name": "NextCursor", "type": "Cursor", "versions": "0+", "nullableVersions": "0+", "default": "null", + "about": "The next topic and partition index to fetch details for.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The name for the first topic to process." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", "about": "The partition index to start with." } + ]} + ] +} diff --git a/test/protocol/admin/test_protocol_admin_partitions.py b/test/protocol/admin/test_protocol_admin_partitions.py new file mode 100644 index 000000000..086b67178 --- /dev/null +++ b/test/protocol/admin/test_protocol_admin_partitions.py @@ -0,0 +1,87 @@ +import uuid + +import pytest + +from kafka.protocol.admin import ( + DescribeTopicPartitionsRequest, DescribeTopicPartitionsResponse, +) + + +def _versions(cls): + lo, hi = cls._valid_versions + return range(lo, hi + 1) + + +@pytest.mark.parametrize("version", _versions(DescribeTopicPartitionsRequest)) +def test_describe_topic_partitions_request_roundtrip(version): + Topic = DescribeTopicPartitionsRequest.TopicRequest + Cursor = DescribeTopicPartitionsRequest.Cursor + request = DescribeTopicPartitionsRequest( + topics=[Topic(name='topic-a'), Topic(name='topic-b')], + response_partition_limit=1000, + cursor=Cursor(topic_name='topic-a', partition_index=5), + ) + encoded = request.encode(version=version) + decoded = 
DescribeTopicPartitionsRequest.decode(encoded, version=version) + assert decoded == request + + +@pytest.mark.parametrize("version", _versions(DescribeTopicPartitionsRequest)) +def test_describe_topic_partitions_request_null_cursor(version): + Topic = DescribeTopicPartitionsRequest.TopicRequest + request = DescribeTopicPartitionsRequest( + topics=[Topic(name='topic-a')], + response_partition_limit=2000, + cursor=None, + ) + encoded = request.encode(version=version) + decoded = DescribeTopicPartitionsRequest.decode(encoded, version=version) + assert decoded == request + + +@pytest.mark.parametrize("version", _versions(DescribeTopicPartitionsResponse)) +def test_describe_topic_partitions_response_roundtrip(version): + Topic = DescribeTopicPartitionsResponse.DescribeTopicPartitionsResponseTopic + Partition = Topic.DescribeTopicPartitionsResponsePartition + Cursor = DescribeTopicPartitionsResponse.Cursor + response = DescribeTopicPartitionsResponse( + throttle_time_ms=0, + topics=[ + Topic( + error_code=0, + name='topic-a', + topic_id=uuid.uuid4(), + is_internal=False, + partitions=[ + Partition( + error_code=0, + partition_index=0, + leader_id=1, + leader_epoch=5, + replica_nodes=[1, 2, 3], + isr_nodes=[1, 2], + eligible_leader_replicas=[3], + last_known_elr=[2], + offline_replicas=[], + ), + ], + topic_authorized_operations=-2147483648, + ), + ], + next_cursor=Cursor(topic_name='topic-a', partition_index=1), + ) + encoded = response.encode(version=version) + decoded = DescribeTopicPartitionsResponse.decode(encoded, version=version) + assert decoded == response + + +@pytest.mark.parametrize("version", _versions(DescribeTopicPartitionsResponse)) +def test_describe_topic_partitions_response_null_cursor(version): + response = DescribeTopicPartitionsResponse( + throttle_time_ms=0, + topics=[], + next_cursor=None, + ) + encoded = response.encode(version=version) + decoded = DescribeTopicPartitionsResponse.decode(encoded, version=version) + assert decoded == response diff 
--git a/test/protocol/schemas/test_encode_parity.py b/test/protocol/schemas/test_encode_parity.py new file mode 100644 index 000000000..ebc0a98c4 --- /dev/null +++ b/test/protocol/schemas/test_encode_parity.py @@ -0,0 +1,352 @@ +"""Verify that StructField.compiled_encode_into (codegen) matches the +non-optimized StructField.encode() output byte-for-byte, across a variety +of schemas and versions. Does the same for compiled_decode_from vs the +reference decode(). + +These tests exercise the nullable-struct null-prefix handling added to +emit_encode_into/emit_decode_from (and the matching reference paths in +encode/encode_into/decode), plus the StructArrayField._nullable_versions +reset that protects non-nullable struct-array elements. +""" +import io +import uuid + +import pytest + +from kafka.protocol.admin import ( + AlterPartitionReassignmentsRequest, + DescribeTopicPartitionsRequest, + DescribeTopicPartitionsResponse, + ListPartitionReassignmentsResponse, +) +from kafka.protocol.metadata import MetadataRequest, MetadataResponse +from kafka.protocol.schemas.fields.base import BaseField +from kafka.protocol.schemas.fields.codecs.encode_buffer import EncodeBuffer +from kafka.protocol.data_container import DataContainer + + +def _reference_encode(api_message): + """Encode using the non-optimized struct.encode() reference path.""" + flexible = api_message.flexible_version_q(api_message.API_VERSION) + return api_message._struct.encode( + api_message, version=api_message.API_VERSION, + compact=flexible, tagged=flexible) + + +def _codegen_encode(api_message): + """Encode using the compiled codegen path (what production uses).""" + flexible = api_message.flexible_version_q(api_message.API_VERSION) + fn = api_message._struct.compiled_encode_into( + api_message.API_VERSION, compact=flexible, tagged=flexible) + out = EncodeBuffer() + fn(api_message, out) + return bytes(out.result()) + + +def _reference_decode(cls, data, version): + flexible = cls.flexible_version_q(version) 
+ bio = io.BytesIO(data) + return cls._struct.decode( + bio, version=version, compact=flexible, tagged=flexible, + data_class=cls[None]) + + +def _codegen_decode(cls, data, version): + flexible = cls.flexible_version_q(version) + fn = cls._struct.compiled_decode_from( + version, compact=flexible, tagged=flexible, data_class=cls[None]) + obj, _pos = fn(memoryview(data), 0) + return obj + + +# --------------------------------------------------------------------------- +# Parity: simple struct, no nullable, no struct array +# --------------------------------------------------------------------------- + + +def test_parity_metadata_request_v0(): + req = MetadataRequest(version=0, topics=['topic-a', 'topic-b']) + assert _reference_encode(req) == _codegen_encode(req) + + +def test_parity_metadata_request_flexible(): + # v9+ is flexible with compact strings and tagged fields + Topic = MetadataRequest.MetadataRequestTopic + req = MetadataRequest( + version=12, + topics=[Topic(topic_id=uuid.uuid4(), name='t1')], + allow_auto_topic_creation=True, + include_topic_authorized_operations=False, + ) + assert _reference_encode(req) == _codegen_encode(req) + + +def test_parity_metadata_request_null_topics(): + # ALL_TOPICS is encoded as null topics array + req = MetadataRequest(version=4, topics=None, + allow_auto_topic_creation=False) + assert _reference_encode(req) == _codegen_encode(req) + + +# --------------------------------------------------------------------------- +# Parity: struct arrays with nested scalar fields +# --------------------------------------------------------------------------- + + +def test_parity_metadata_response_with_topics(): + Broker = MetadataResponse.MetadataResponseBroker + Topic = MetadataResponse.MetadataResponseTopic + Partition = Topic.MetadataResponsePartition + resp = MetadataResponse( + version=7, + throttle_time_ms=0, + brokers=[Broker(node_id=1, host='h1', port=9092, rack='r1')], + cluster_id='c1', + controller_id=1, + topics=[ + Topic( + 
error_code=0, + name='t1', + is_internal=False, + partitions=[ + Partition( + error_code=0, partition_index=0, leader_id=1, + leader_epoch=5, + replica_nodes=[1], isr_nodes=[1], + offline_replicas=[], + ), + ], + ), + ], + ) + assert _reference_encode(resp) == _codegen_encode(resp) + + +def test_parity_alter_partition_reassignments_request(): + Topic = AlterPartitionReassignmentsRequest.ReassignableTopic + Partition = Topic.ReassignablePartition + req = AlterPartitionReassignmentsRequest( + version=0, + timeout_ms=30000, + topics=[ + Topic( + name='topic-a', + partitions=[ + Partition(partition_index=0, replicas=[1, 2, 3]), + # cancel: null replicas + Partition(partition_index=1, replicas=None), + ], + ), + ], + ) + assert _reference_encode(req) == _codegen_encode(req) + + +# --------------------------------------------------------------------------- +# Parity: nullable nested struct (DescribeTopicPartitionsResponse.cursor) +# --------------------------------------------------------------------------- + + +def test_parity_describe_topic_partitions_request_cursor_present(): + _Topic = DescribeTopicPartitionsRequest.TopicRequest + _Cursor = DescribeTopicPartitionsRequest.Cursor + req = DescribeTopicPartitionsRequest( + version=0, + topics=[_Topic(name='t1'), _Topic(name='t2')], + response_partition_limit=500, + cursor=_Cursor(topic_name='t1', partition_index=3), + ) + assert _reference_encode(req) == _codegen_encode(req) + + +def test_parity_describe_topic_partitions_request_cursor_null(): + _Topic = DescribeTopicPartitionsRequest.TopicRequest + req = DescribeTopicPartitionsRequest( + version=0, + topics=[_Topic(name='t1')], + response_partition_limit=2000, + cursor=None, + ) + assert _reference_encode(req) == _codegen_encode(req) + + +def test_parity_describe_topic_partitions_response_cursor_present(): + _Topic = DescribeTopicPartitionsResponse.DescribeTopicPartitionsResponseTopic + _Cursor = DescribeTopicPartitionsResponse.Cursor + resp = 
DescribeTopicPartitionsResponse( + version=0, + throttle_time_ms=0, + topics=[], + next_cursor=_Cursor(topic_name='t1', partition_index=7), + ) + assert _reference_encode(resp) == _codegen_encode(resp) + + +def test_parity_describe_topic_partitions_response_cursor_null(): + resp = DescribeTopicPartitionsResponse( + version=0, throttle_time_ms=0, topics=[], next_cursor=None) + assert _reference_encode(resp) == _codegen_encode(resp) + + +def test_parity_describe_topic_partitions_response_with_elr(): + _Topic = DescribeTopicPartitionsResponse.DescribeTopicPartitionsResponseTopic + _Partition = _Topic.DescribeTopicPartitionsResponsePartition + resp = DescribeTopicPartitionsResponse( + version=0, + throttle_time_ms=0, + topics=[ + _Topic( + error_code=0, + name='t1', + topic_id=uuid.uuid4(), + is_internal=False, + partitions=[ + _Partition( + error_code=0, partition_index=0, leader_id=1, + leader_epoch=5, replica_nodes=[1, 2, 3], + isr_nodes=[1, 2], eligible_leader_replicas=[3], + last_known_elr=[2], offline_replicas=[], + ), + ], + topic_authorized_operations=0, + ), + ], + next_cursor=None, + ) + assert _reference_encode(resp) == _codegen_encode(resp) + + +# --------------------------------------------------------------------------- +# Parity: nullable error_message (flexible compact nullable string) in a +# list/response with struct arrays +# --------------------------------------------------------------------------- + + +def test_parity_list_partition_reassignments_response(): + _Topic = ListPartitionReassignmentsResponse.OngoingTopicReassignment + _Partition = _Topic.OngoingPartitionReassignment + resp = ListPartitionReassignmentsResponse( + version=0, + throttle_time_ms=0, + error_code=0, + error_message=None, + topics=[ + _Topic( + name='t1', + partitions=[ + _Partition( + partition_index=0, replicas=[1, 2, 3], + adding_replicas=[4], removing_replicas=[1]), + ], + ), + ], + ) + assert _reference_encode(resp) == _codegen_encode(resp) + + +# 
--------------------------------------------------------------------------- +# Decode parity: round-trip the codegen encoder output through both decoders. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize('cls,kwargs,version', [ + (DescribeTopicPartitionsRequest, + dict(topics=[DescribeTopicPartitionsRequest.TopicRequest(name='t1')], + response_partition_limit=10, + cursor=DescribeTopicPartitionsRequest.Cursor( + topic_name='t1', partition_index=2)), + 0), + (DescribeTopicPartitionsRequest, + dict(topics=[DescribeTopicPartitionsRequest.TopicRequest(name='t1')], + response_partition_limit=10, cursor=None), + 0), + (DescribeTopicPartitionsResponse, + dict(throttle_time_ms=0, topics=[], next_cursor=None), + 0), + (DescribeTopicPartitionsResponse, + dict(throttle_time_ms=0, topics=[], + next_cursor=DescribeTopicPartitionsResponse.Cursor( + topic_name='t1', partition_index=5)), + 0), + (MetadataRequest, + dict(topics=['a', 'b']), + 0), + (MetadataRequest, + dict(topics=None, allow_auto_topic_creation=False), + 4), +]) +def test_decode_parity(cls, kwargs, version): + msg = cls(version=version, **kwargs) + encoded = _codegen_encode(msg) + + via_codegen = _codegen_decode(cls, encoded, version) + via_reference = _reference_decode(cls, encoded, version) + + # Compare full field-tree via to_dict (cursor=None case exercises the + # _to_dict_vals None-guard in DataContainer). + assert via_codegen.to_dict() == via_reference.to_dict() + + +# --------------------------------------------------------------------------- +# Targeted: build a minimal struct schema with a nullable nested struct to +# verify the null-prefix byte wire format directly (independent of any real +# protocol schema). 
+# --------------------------------------------------------------------------- + + +def _make_struct_with_nullable_child(): + return BaseField.parse_json({ + 'name': 'Outer', + 'versions': '0+', + 'validVersions': '0', + 'flexibleVersions': '0+', + 'type': 'Outer', + 'fields': [ + {'name': 'x', 'type': 'int32', 'versions': '0+'}, + {'name': 'child', 'type': 'Child', 'versions': '0+', + 'nullableVersions': '0+', + 'fields': [ + {'name': 'y', 'type': 'int32', 'versions': '0+'}, + ]}, + ], + }) + + +def _container_from_struct(struct): + Outer = type('Outer', (DataContainer,), {'_struct': struct}) + return Outer + + +def test_nullable_child_wire_format_null(): + struct = _make_struct_with_nullable_child() + Outer = _container_from_struct(struct) + # Build inner child class via the field's data_class (set by __init_subclass__). + item = Outer(version=0, x=7, child=None) + encoded = struct.encode(item, version=0, compact=True, tagged=None) + # int32 x=7 (4 bytes), nullable struct prefix 0x00, empty tagged fields 0x00 + assert encoded == b'\x00\x00\x00\x07\x00\x00' + + +def test_nullable_child_wire_format_present(): + struct = _make_struct_with_nullable_child() + Outer = _container_from_struct(struct) + Child = struct.fields['child'].data_class + item = Outer(version=0, x=7, child=Child(version=0, y=9)) + encoded = struct.encode(item, version=0, compact=True, tagged=None) + # int32 x=7, prefix 0x01, int32 y=9, inner struct's empty tagged 0x00, + # outer struct's empty tagged 0x00 + assert encoded == b'\x00\x00\x00\x07\x01\x00\x00\x00\x09\x00\x00' + + +def test_nullable_child_codegen_matches_reference(): + struct = _make_struct_with_nullable_child() + Outer = _container_from_struct(struct) + Child = struct.fields['child'].data_class + + for child_val in (None, Child(version=0, y=42)): + item = Outer(version=0, x=1, child=child_val) + ref = struct.encode(item, version=0, compact=True, tagged=None) + fn = struct.compiled_encode_into(0, compact=True, tagged=None) + out = 
EncodeBuffer() + fn(item, out) + assert bytes(out.result()) == ref