def _json_default(obj: Any) -> Any:
    """Fallback encoder handed to `orjson.dumps(default=...)`.

    Covers the two value types this module expects from BigQuery rows
    that `orjson` will not serialise on its own:

    - `bytes` (BigQuery `BYTES` columns) become a base64 ASCII string,
      so the response body stays valid UTF-8 and can be decoded back.
    - `decimal.Decimal` (BigQuery `NUMERIC` / `BIGNUMERIC` columns)
      becomes its `str()` form. A string keeps every digit, whereas a
      JSON number / IEEE-754 double would lose precision; it also
      matches CKAN's datastore convention of returning high-precision
      numerics as strings.

    Raises:
        TypeError: for any other type, so that an unexpected value
            fails loudly instead of being silently coerced.
    """
    if isinstance(obj, bytes):
        encoded = base64.b64encode(obj)
        return encoded.decode("ascii")
    if not isinstance(obj, Decimal):
        # Unknown type: refuse rather than guess at a representation.
        raise TypeError(
            f"orjson cannot serialise {type(obj).__name__}; "
            "extend `_json_default` if a new BigQuery type comes through."
        )
    return str(obj)
def _join(parts: list[bytes]) -> str:
    """Concatenate the streamed byte chunks and decode them as UTF-8."""
    payload = b"".join(parts)
    return payload.decode("utf-8")


def test_records_object_array_serialises_decimal_and_bytes() -> None:
    """Object-format rows holding a Decimal (NUMERIC) and a bytes (BYTES)
    value must stream to valid JSON: the Decimal arrives as its string
    form and the bytes value as base64 text."""
    columns = ["product_code", "clearing_price_gbp_per_mwh", "signature"]
    tiny_price = Decimal("0.00000000000000000000000000000000000001")
    rows = iter(
        [
            ("DCL", Decimal("47.82"), b"\x00\xff"),
            ("DCH", tiny_price, b"abc"),
        ]
    )

    chunks = list(_records_object_array(columns, rows))
    decoded = json.loads(_join(chunks))

    expected = [
        {
            "product_code": "DCL",
            "clearing_price_gbp_per_mwh": "47.82",
            "signature": "AP8=",  # base64 of b"\x00\xff"
        },
        {
            "product_code": "DCH",
            # str(Decimal) switches to scientific notation for this value.
            "clearing_price_gbp_per_mwh": "1E-38",
            "signature": "YWJj",  # base64 of b"abc"
        },
    ]
    assert decoded == expected


def test_records_array_array_serialises_decimal_and_bytes() -> None:
    """The list-of-lists writer (`records_format=lists`) gets the same
    Decimal / bytes coverage."""
    single_row = ("DCL", Decimal("47.82"), b"\x00\xff")

    chunks = list(_records_array_array(iter([single_row])))
    decoded = json.loads(_join(chunks))

    assert decoded == [["DCL", "47.82", "AP8="]]


def test_unsupported_type_still_raises() -> None:
    """An unknown type must not be swallowed by the default hook: the
    stream has to fail with a TypeError that names the offending type."""

    class Mystery:
        pass

    rows = iter([(Mystery(),)])
    caught = None
    try:
        list(_records_array_array(rows))
    except TypeError as exc:
        caught = exc

    if caught is None:
        raise AssertionError("expected TypeError for unsupported type")
    assert "Mystery" in str(caught)