Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "breaking",
"description": "Replace `FieldPosition` enum values with string literals for `Field.kind`. Use \"header\" and \"trailer\" instead of `FieldPosition.HEADER` and `FieldPosition.TRAILER`."
}
12 changes: 8 additions & 4 deletions packages/aws-sdk-signers/src/aws_sdk_signers/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
from copy import deepcopy
from dataclasses import dataclass
from functools import cached_property
from typing import TypedDict
from typing import TypedDict, get_args
from urllib.parse import urlunparse

import aws_sdk_signers.interfaces.http as interfaces_http

_VALID_FIELD_POSITIONS = frozenset(get_args(interfaces_http.FieldPosition))


class Field(interfaces_http.Field):
"""A name-value pair representing a single field in an HTTP Request or Response.
Expand All @@ -36,10 +38,12 @@ def __init__(
*,
name: str,
values: Iterable[str] | None = None,
kind: interfaces_http.FieldPosition = interfaces_http.FieldPosition.HEADER,
kind: interfaces_http.FieldPosition = "header",
):
self.name = name
self.values: list[str] = list(values) if values is not None else []
if kind not in _VALID_FIELD_POSITIONS:
raise ValueError(f"Unknown field kind: {kind!r}")
self.kind = kind

def add(self, value: str) -> None:
Expand Down Expand Up @@ -92,7 +96,7 @@ def __eq__(self, other: object) -> bool:
return False
return (
self.name == other.name
and self.kind is other.kind
and self.kind == other.kind
and self.values == other.values
)

Expand Down Expand Up @@ -168,7 +172,7 @@ def get_by_type(

Used to grab all headers or all trailers.
"""
return [entry for entry in self.entries.values() if entry.kind is kind]
return [entry for entry in self.entries.values() if entry.kind == kind]

def extend(self, other: interfaces_http.Fields) -> None:
"""Merges ``entries`` of ``other`` into the current ``entries``.
Expand Down
29 changes: 9 additions & 20 deletions packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,18 @@

from collections import OrderedDict
from collections.abc import AsyncIterable, Iterable, Iterator
from enum import Enum
from typing import Protocol, runtime_checkable
from typing import Literal, Protocol, runtime_checkable

FieldPosition = Literal["header", "trailer"]
"""The type of a field.

class FieldPosition(Enum):
"""The type of a field.
Defines its placement in a request or response.

Defines its placement in a request or response.
"""

HEADER = 0
"""Header field.

In HTTP this is a header as defined in RFC 9110 Section 6.3. Implementations of
other protocols may use this FieldPosition for similar types of metadata.
"""

TRAILER = 1
"""Trailer field.
header: Header field. In HTTP this is a header as defined in RFC 9110 Section 6.3.
trailer: Trailer field. In HTTP this is a trailer as defined in RFC 9110 Section 6.5.

In HTTP this is a trailer as defined in RFC 9110 Section 6.5. Implementations of
other protocols may use this FieldPosition for similar types of metadata.
"""
Implementations of other protocols may use this FieldPosition for similar types of metadata.
"""


class Field(Protocol):
Expand All @@ -43,7 +32,7 @@ class Field(Protocol):

name: str
values: list[str]
kind: FieldPosition = FieldPosition.HEADER
kind: FieldPosition = "header"

def add(self, value: str) -> None:
"""Append a value to a field."""
Expand Down
46 changes: 25 additions & 21 deletions packages/aws-sdk-signers/tests/unit/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,21 @@

import pytest
from aws_sdk_signers import Field, Fields
from aws_sdk_signers.interfaces.http import FieldPosition


def test_field_single_valued_basics() -> None:
field = Field(name="fname", values=["fval"], kind=FieldPosition.HEADER)
field = Field(name="fname", values=["fval"], kind="header")
assert field.name == "fname"
assert field.kind == FieldPosition.HEADER
assert field.kind == "header"
assert field.values == ["fval"]
assert field.as_string() == "fval"
assert field.as_tuples() == [("fname", "fval")]


def test_field_multi_valued_basics() -> None:
field = Field(name="fname", values=["fval1", "fval2"], kind=FieldPosition.HEADER)
field = Field(name="fname", values=["fval1", "fval2"], kind="header")
assert field.name == "fname"
assert field.kind == FieldPosition.HEADER
assert field.kind == "header"
assert field.values == ["fval1", "fval2"]
assert field.as_string() == "fval1,fval2"
assert field.as_tuples() == [("fname", "fval1"), ("fname", "fval2")]
Expand Down Expand Up @@ -62,16 +61,16 @@ def test_field_serialization(values: list[str], expected: str) -> None:
"field,expected_repr",
[
(
Field(name="fname", values=["fval1", "fval2"], kind=FieldPosition.HEADER),
"Field(name='fname', value=['fval1', 'fval2'], kind=<FieldPosition.HEADER: 0>)",
Field(name="fname", values=["fval1", "fval2"], kind="header"),
"Field(name='fname', value=['fval1', 'fval2'], kind='header')",
),
(
Field(name="fname", kind=FieldPosition.TRAILER),
"Field(name='fname', value=[], kind=<FieldPosition.TRAILER: 1>)",
Field(name="fname", kind="trailer"),
"Field(name='fname', value=[], kind='trailer')",
),
(
Field(name="fname"),
"Field(name='fname', value=[], kind=<FieldPosition.HEADER: 0>)",
"Field(name='fname', value=[], kind='header')",
),
],
)
Expand All @@ -83,8 +82,8 @@ def test_field_repr(field: Field, expected_repr: str) -> None:
"f1,f2",
[
(
Field(name="fname", values=["fval1", "fval2"], kind=FieldPosition.TRAILER),
Field(name="fname", values=["fval1", "fval2"], kind=FieldPosition.TRAILER),
Field(name="fname", values=["fval1", "fval2"], kind="trailer"),
Field(name="fname", values=["fval1", "fval2"], kind="trailer"),
),
(
Field(name="fname", values=["fval1", "fval2"]),
Expand All @@ -104,20 +103,20 @@ def test_field_equality(f1: Field, f2: Field) -> None:
"f1,f2",
[
(
Field(name="fname", values=["fval1", "fval2"], kind=FieldPosition.HEADER),
Field(name="fname", values=["fval1", "fval2"], kind=FieldPosition.TRAILER),
Field(name="fname", values=["fval1", "fval2"], kind="header"),
Field(name="fname", values=["fval1", "fval2"], kind="trailer"),
),
(
Field(name="fname", values=["fval1", "fval2"], kind=FieldPosition.HEADER),
Field(name="fname", values=["fval2", "fval1"], kind=FieldPosition.HEADER),
Field(name="fname", values=["fval1", "fval2"], kind="header"),
Field(name="fname", values=["fval2", "fval1"], kind="header"),
),
(
Field(name="fname", values=["fval1", "fval2"], kind=FieldPosition.HEADER),
Field(name="fname", values=["fval1"], kind=FieldPosition.HEADER),
Field(name="fname", values=["fval1", "fval2"], kind="header"),
Field(name="fname", values=["fval1"], kind="header"),
),
(
Field(name="fname1", values=["fval1", "fval2"], kind=FieldPosition.HEADER),
Field(name="fname2", values=["fval1", "fval2"], kind=FieldPosition.HEADER),
Field(name="fname1", values=["fval1", "fval2"], kind="header"),
Field(name="fname2", values=["fval1", "fval2"], kind="header"),
),
],
)
Expand Down Expand Up @@ -211,7 +210,7 @@ def test_fields_length_value(fields: Fields, expected_length: int) -> None:
Fields([Field(name="fname1")]),
(
"Fields(OrderedDict({'fname1': Field(name='fname1', value=[], "
"kind=<FieldPosition.HEADER: 0>)}))"
"kind='header')}))"
),
),
],
Expand Down Expand Up @@ -314,3 +313,8 @@ def test_fields_delitem_missing() -> None:
fields = Fields([Field(name="fname1")])
with pytest.raises(KeyError):
del fields["fname2"]


def test_field_invalid_kind() -> None:
with pytest.raises(ValueError, match="Unknown field kind"):
Field(name="fname", kind="metadata") # type: ignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "breaking",
"description": "Replace `FieldPosition` enum values with string literals for `Field.kind`. Use \"header\" and \"trailer\" instead of `FieldPosition.HEADER` and `FieldPosition.TRAILER`."
}
15 changes: 9 additions & 6 deletions packages/smithy-http/src/smithy_http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
# SPDX-License-Identifier: Apache-2.0
from collections import Counter, OrderedDict
from collections.abc import Iterable, Iterator
from typing import get_args

from . import interfaces
from .interfaces import FieldPosition

__version__ = "0.3.1"

_VALID_FIELD_POSITIONS = frozenset(get_args(FieldPosition))


class Field(interfaces.Field):
"""A name-value pair representing a single field in an HTTP Request or Response.
Expand All @@ -24,10 +27,12 @@ def __init__(
*,
name: str,
values: Iterable[str] | None = None,
kind: FieldPosition = FieldPosition.HEADER,
kind: FieldPosition = "header",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same suggestion here as with aws-sdk-signers: We should add some runtime validation.

):
self.name = name
self.values: list[str] = list(values) if values is not None else []
if kind not in _VALID_FIELD_POSITIONS:
raise ValueError(f"Unknown field kind: {kind!r}")
self.kind = kind

def add(self, value: str) -> None:
Expand Down Expand Up @@ -79,7 +84,7 @@ def __eq__(self, other: object) -> bool:
return False
return (
self.name == other.name
and self.kind is other.kind
and self.kind == other.kind
and self.values == other.values
)

Expand Down Expand Up @@ -153,7 +158,7 @@ def get_by_type(self, kind: FieldPosition) -> list[interfaces.Field]:

Used to grab all headers or all trailers.
"""
return [entry for entry in self.entries.values() if entry.kind is kind]
return [entry for entry in self.entries.values() if entry.kind == kind]

def extend(self, other: interfaces.Fields) -> None:
"""Merges ``entries`` of ``other`` into the current ``entries``.
Expand Down Expand Up @@ -225,8 +230,6 @@ def tuples_to_fields(
try:
fields[name].add(value)
except KeyError:
fields[name] = Field(
name=name, values=[value], kind=kind or FieldPosition.HEADER
)
fields[name] = Field(name=name, values=[value], kind=kind or "header")

return fields
3 changes: 1 addition & 2 deletions packages/smithy-http/src/smithy_http/aio/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

from .. import Field, Fields
from ..interfaces import (
FieldPosition,
HTTPClientConfiguration,
HTTPRequestConfiguration,
)
Expand Down Expand Up @@ -125,7 +124,7 @@ async def _marshal_response(
headers[header_name] = Field(
name=header_name,
values=[header_val],
kind=FieldPosition.HEADER,
kind="header",
)

return HTTPResponse(
Expand Down
6 changes: 2 additions & 4 deletions packages/smithy-http/src/smithy_http/aio/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from .. import Field, Fields
from .. import interfaces as http_interfaces
from ..exceptions import SmithyHTTPError
from ..interfaces import FieldPosition
from . import interfaces as http_aio_interfaces

# Default buffer size for reading from streams (8 KB)
Expand Down Expand Up @@ -203,7 +202,7 @@ async def _await_response(
fields[header_name] = Field(
name=header_name,
values=[header_val],
kind=FieldPosition.HEADER,
kind="header",
)
return AWSCRTHTTPResponse(
status=status_code,
Expand Down Expand Up @@ -294,8 +293,7 @@ def _marshal_request(
request.fields.set_field(Field(name="accept", values=["*/*"]))

for fld in request.fields.entries.values():
# TODO: Use literal values for "header"/"trailer".
if fld.kind.value != FieldPosition.HEADER.value:
if fld.kind != "header":
continue
for val in fld.values:
headers_list.append((fld.name, val))
Expand Down
29 changes: 9 additions & 20 deletions packages/smithy-http/src/smithy_http/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,18 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum
from typing import Protocol
from typing import Literal, Protocol

FieldPosition = Literal["header", "trailer"]
"""The type of a field.

class FieldPosition(Enum):
"""The type of a field.
Defines its placement in a request or response.

Defines its placement in a request or response.
"""

HEADER = 0
"""Header field.

In HTTP this is a header as defined in RFC 9110 Section 6.3. Implementations of
other protocols may use this FieldPosition for similar types of metadata.
"""

TRAILER = 1
"""Trailer field.
header: Header field. In HTTP this is a header as defined in RFC 9110 Section 6.3.
trailer: Trailer field. In HTTP this is a trailer as defined in RFC 9110 Section 6.5.

In HTTP this is a trailer as defined in RFC 9110 Section 6.5. Implementations of
other protocols may use this FieldPosition for similar types of metadata.
"""
Implementations of other protocols may use this FieldPosition for similar types of metadata.
"""


class Field(Protocol):
Expand All @@ -40,7 +29,7 @@ class Field(Protocol):

name: str
values: list[str]
kind: FieldPosition = FieldPosition.HEADER
kind: FieldPosition = "header"

def add(self, value: str) -> None:
"""Append a value to a field."""
Expand Down
Loading
Loading