diff --git a/docs/language/reference/builders/aggregates.md b/docs/language/reference/builders/aggregates.md index 626bf94..e5e0c82 100644 --- a/docs/language/reference/builders/aggregates.md +++ b/docs/language/reference/builders/aggregates.md @@ -16,6 +16,7 @@ Current aggregate authoring is explicit and scalar-expression-based. | `avg` | `def avg(expr: ColumnExpr) -> AggregateMeasure` | Average one numeric scalar expression. | | `min` | `def min(expr: ColumnExpr) -> AggregateMeasure` | Return the minimum non-null value for one orderable scalar expression. | | `max` | `def max(expr: ColumnExpr) -> AggregateMeasure` | Return the maximum non-null value for one orderable scalar expression. | +| `approx_count_distinct` | `def approx_count_distinct(expr: ColumnExpr) -> AggregateMeasure` | Estimate distinct non-null expression values. | ## Modifiers @@ -30,7 +31,7 @@ Aggregate measures support method-style modifiers: ## Example ```incan -from pub::inql.functions import add, avg, col, count, count_distinct, count_expr, count_if, eq, lit, max, min, str_lit, sum +from pub::inql.functions import add, approx_count_distinct, avg, col, count, count_distinct, count_expr, count_if, eq, lit, max, min, str_lit, sum grouped = orders.group_by([col("customer_id")]).agg([ sum(add(col("amount"), lit(5))), @@ -42,6 +43,7 @@ grouped = orders.group_by([col("customer_id")]).agg([ avg(col("amount")), min(col("created_at")), max(col("created_at")), + approx_count_distinct(col("user_id")), ]) ``` @@ -54,5 +56,7 @@ grouped = orders.group_by([col("customer_id")]).agg([ - `count_if(predicate)` is compatibility sugar for `count().filter(predicate)`. Rows where the predicate is false or null do not contribute to the aggregate. - `sum`, `avg`, `min`, and `max` skip null values. They return backend-null results when no non-null input value exists. +- `approx_count_distinct(expr)` is approximate by contract, skips null values, allows aggregate-local filters, and rejects + an extra `distinct()` modifier because distinct estimation is already the helper's semantics. - Unsupported aggregate modifiers fail at lowering or backend planning; they are not ignored. - Future `.column` sugar and scoped aggregate symbols should lower to this same surface rather than replacing its semantics. diff --git a/docs/language/reference/functions/approximate.md b/docs/language/reference/functions/approximate.md new file mode 100644 index 0000000..4aea07a --- /dev/null +++ b/docs/language/reference/functions/approximate.md @@ -0,0 +1,31 @@ +# Approximate Functions (Reference) + +Approximate helpers are explicit opt-in functions. InQL does not silently replace exact aggregates with approximate +execution because a backend can do so. + +The current implemented slice is one aggregate: + +| Function | Meaning | +| --- | --- | +| `approx_count_distinct(expr)` | Estimate the number of distinct non-null values produced by one expression. | + +```incan +from pub::inql.functions import approx_count_distinct, col + +summary = ( + events + .group_by([col("campaign_id")]) + .agg([approx_count_distinct(col("user_id"))]) +) +``` + +`approx_count_distinct` is registered as an approximate aggregate with HyperLogLog-family metadata. The portable author +contract is an approximate non-null distinct-count estimate; the first slice does not expose a user-tunable relative +error parameter because the standard Substrait mapping for this function is unary. Backend adapters must keep this +approximation visible in capability/error handling rather than redefining exact `count_distinct` semantics. + +The helper lowers through the standard Substrait `approx_count_distinct` aggregate extension name. The DataFusion +adapter maps that declaration to DataFusion's `approx_distinct` implementation name at the backend boundary. + +Approximate percentile functions, sketch-state values, sketch serialization, and sketch merge/estimate helpers remain +future slices until their accuracy parameters, logical sketch types, and compatibility rules are explicit. diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index 23ec6d0..e1056a5 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -11,10 +11,11 @@ Today the concrete shipped surfaces are documented here: - [Nested data functions](nested.md) - [Window functions](windows.md) - [Format functions](format.md) +- [Approximate functions](approximate.md) The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. -The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, nested data, windows, and format helpers. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. +The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, nested data, windows, format helpers, and approximate aggregates. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, approximation metadata, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. @@ -42,5 +43,6 @@ The registered helper surface currently includes: | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | | `sum(...)`, `count()`, `count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` helper overloading | | `count_distinct(...)`, `count_if(...)` | aggregate | compatibility helpers that lower through aggregate modifiers over canonical `count` semantics | +| `approx_count_distinct(...)` | aggregate | approximate aggregate that lowers through the standard Substrait `approx_count_distinct` extension and is adapted to DataFusion's `approx_distinct` implementation at the backend boundary | Future ANSI-style families should grow under this section instead of bloating `dataset_types` or `dataset_methods`. diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index 9be3264..d2d4ae7 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -19,6 +19,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Generator functions:** RFC 021 adds registry-backed generator applications for `explode(...)`, `explode_outer(...)`, `posexplode(...)`, and `posexplode_outer(...)`. Generators remain relation-shaping operations applied with `generate(...)`; they preserve input columns, require explicit output aliases, and lower through the current Substrait extension-relation gap encoding. - **Window functions:** RFC 019 adds the first window-function planning slice with `window()` specs, `row_number()`, `rank()`, `dense_rank()`, and `with_window_column(...)`. Ranking windows require explicit ordering and lower through Substrait `ConsistentPartitionWindowRel`; backend execution support remains a separate adapter capability. - **Format functions:** RFC 022 adds the first deterministic hashing slice with `md5(...)`, `sha224(...)`, `sha256(...)`, `sha384(...)`, `sha512(...)`, and `sha2(...)`. Hash helpers operate on UTF-8 string bytes, return lowercase hexadecimal strings, lower through registry-owned Substrait metadata, and execute through the DataFusion-backed Session path. +- **Approximate functions:** RFC 023 adds the first approximate aggregate slice with `approx_count_distinct(...)`. The helper is opt-in, marked approximate in registry metadata, lowers through the standard Substrait `approx_count_distinct` aggregate extension name, and executes through the DataFusion-backed Session path via an adapter-local mapping to DataFusion's `approx_distinct` implementation. - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation. - **Function extension policy:** RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. diff --git a/docs/rfcs/023_approximate_sketch_functions.md b/docs/rfcs/023_approximate_sketch_functions.md index 4e9b925..31a301f 100644 --- a/docs/rfcs/023_approximate_sketch_functions.md +++ b/docs/rfcs/023_approximate_sketch_functions.md @@ -1,6 +1,6 @@ # InQL RFC 023: Approximate and sketch functions -- **Status:** Draft +- **Status:** In Progress - **Created:** 2026-04-27 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -11,7 +11,7 @@ - InQL RFC 024 (function extension policy) - **Issue:** [InQL #40](https://github.com/dannys-code-corner/InQL/issues/40) - **RFC PR:** — -- **Written against:** Incan v0.2 +- **Written against:** Incan v0.3-era InQL - **Shipped in:** — ## Summary @@ -112,10 +112,54 @@ This RFC is additive. Existing exact aggregates must not change semantics when a - **Execution / interchange** — Prism and Substrait lowering must preserve approximate parameters, sketch state types, and merge semantics or reject unsupported functions. - **Documentation** — docs must label approximate functions clearly and explain accuracy parameters. -## Unresolved questions +## Design Decisions + +### Resolved + +- The first implementation slice is `approx_count_distinct(expr)`. It is an aggregate measure, not a scalar expression, + and its helper name makes approximate execution an explicit author choice. +- `approx_count_distinct` is registered as approximate metadata with HyperLogLog-family semantics, mergeability, and an + approximate cardinality-result interpretation. +- The first slice follows the standard Substrait unary `approx_count_distinct` aggregate mapping. It does not expose a + user-tunable relative-error parameter because the validated standard mapping does not carry one. +- DataFusion's implementation is named `approx_distinct`; InQL keeps the standard Substrait function name in emitted + function metadata and rewrites only the DataFusion consumer declaration at the backend adapter boundary. +- `approx_count_distinct` allows aggregate-local filters and rejects an extra `distinct()` modifier because distinct + estimation is already the helper's semantics. +- `approx_percentile` is not implemented in this slice because the local Substrait aggregate-approx extension has a + standard `approx_count_distinct` mapping but no matching standard approximate percentile contract to preserve. +- Sketch-state construction, merge, estimate, serialization, and deserialization helpers remain future work until InQL + has explicit sketch logical types and compatibility rules. + +### Remaining - Should InQL standardize one sketch family per use case or expose multiple named families? - What serialization format, if any, should be portable across backends? - How should accuracy guarantees be documented without implying backend-independent statistical promises that are not true? - - +- Should future approximate aggregates expose user-tunable accuracy parameters through aggregate options, option records, + or separate helper names when Substrait has no standard parameter slot? +- Which approximate percentile family should become the portable core contract, and how should percentile domain, + interpolation, and accuracy be specified? + +## Implementation Plan + +1. Add registry approximation metadata with exact-helper defaults. +2. Add `approx_count_distinct(expr)` under a logical approximate function family. +3. Add a stable Substrait anchor and keep emitted function metadata on the standard `approx_count_distinct` name. +4. Add a DataFusion adapter-local rewrite to `approx_distinct` for the first backend. +5. Add focused helper, registry, Substrait lowering, Prism, and DataFusion-backed session tests with materialized output. +6. Add user-facing approximate-function docs, aggregate-builder docs, and release notes. +7. Leave approximate percentile and sketch-state helpers for later RFC 023 slices once remaining contracts are resolved. + +## Progress Checklist + +- [x] RFC 023 moved to In Progress with a first implementation slice and recorded design decisions. +- [x] Registry approximation metadata added for intentionally approximate functions. +- [x] `approx_count_distinct` helper added under the function catalog. +- [x] Standard Substrait `approx_count_distinct` extension metadata added. +- [x] DataFusion adapter-local `approx_count_distinct` to `approx_distinct` mapping added. +- [x] Focused helper, registry, Substrait lowering, Prism, and DataFusion-backed session tests added. +- [x] User-facing approximate-function docs, aggregate-builder docs, and release notes added. +- [ ] Approximate percentile semantics specified and implemented. +- [ ] Sketch-state logical types specified and implemented. +- [ ] Sketch merge, estimate, serialize, and deserialize helpers specified and implemented. diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md index 290c162..d0ff2ad 100644 --- a/docs/rfcs/README.md +++ b/docs/rfcs/README.md @@ -29,7 +29,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la | [020][rfc-020] | Draft | Nested data functions | | | [021][rfc-021] | In Progress | Generator and table-valued functions | | | [022][rfc-022] | In Progress | Semi-structured and format functions | | -| [023][rfc-023] | Draft | Approximate and sketch functions | | +| [023][rfc-023] | In Progress | Approximate and sketch functions | | | [024][rfc-024] | Draft | Function extension policy | | diff --git a/src/aggregate_builders.incn b/src/aggregate_builders.incn index 755f66b..2e581a9 100644 --- a/src/aggregate_builders.incn +++ b/src/aggregate_builders.incn @@ -19,6 +19,7 @@ pub enum AggregateKind(str): Avg = "avg" Min = "min" Max = "max" + ApproxCountDistinct = "approx_count_distinct" @derive(Clone) @@ -137,3 +138,8 @@ pub def min(expr: ColumnExpr) -> AggregateMeasure: pub def max(expr: ColumnExpr) -> AggregateMeasure: """Build one `max` aggregate measure over a scalar expression.""" return _aggregate_measure("max", AggregateKind.Max, expr, true) + + +pub def approx_count_distinct(expr: ColumnExpr) -> AggregateMeasure: + """Build one approximate distinct-count aggregate measure over a scalar expression.""" + return _aggregate_measure("approx_count_distinct", AggregateKind.ApproxCountDistinct, expr, true) diff --git a/src/function_registry.incn b/src/function_registry.incn index 2ac97ff..5124b30 100644 --- a/src/function_registry.incn +++ b/src/function_registry.incn @@ -128,6 +128,17 @@ pub model AggregateModifierPolicy: pub allows_ordered_input: bool +@derive(Clone) +pub model FunctionApproximationPolicy: + """Approximation contract metadata for functions with intentionally approximate results.""" + + pub is_approximate: bool + pub algorithm: str + pub default_error: str + pub mergeable: bool + pub result_interpretation: str + + @derive(Clone) pub model SubstraitMapping: """Portable interchange mapping metadata for one registered function.""" @@ -154,6 +165,7 @@ pub model FunctionSpec: pub null_behavior: FunctionNullBehavior pub error_behavior: FunctionErrorBehavior pub aggregate_modifiers: AggregateModifierPolicy + pub approximation: FunctionApproximationPolicy pub substrait: SubstraitMapping @@ -173,6 +185,7 @@ pub model FunctionRegistryEntry: pub null_behavior: FunctionNullBehavior pub error_behavior: FunctionErrorBehavior pub aggregate_modifiers: AggregateModifierPolicy + pub approximation: FunctionApproximationPolicy pub substrait: SubstraitMapping @@ -226,6 +239,7 @@ pub class FunctionRegistry: null_behavior=spec.null_behavior, error_behavior=spec.error_behavior, aggregate_modifiers=spec.aggregate_modifiers, + approximation=spec.approximation, substrait=spec.substrait, ), ) @@ -378,6 +392,11 @@ pub def core_aggregate_modifier_policy() -> AggregateModifierPolicy: return aggregate_modifier_policy(true, true, false) +pub def approximate_distinct_aggregate_modifier_policy() -> AggregateModifierPolicy: + """Return the modifier policy for aggregates whose semantics already include distinct estimation.""" + return aggregate_modifier_policy(false, true, false) + + def _aggregate_modifier_policy_for_class(function_class: FunctionClass) -> AggregateModifierPolicy: """Return the default modifier policy for a semantic function class.""" if function_class == FunctionClass.Aggregate: @@ -385,6 +404,38 @@ def _aggregate_modifier_policy_for_class(function_class: FunctionClass) -> Aggre return no_aggregate_modifiers() +pub def approximation_policy( + is_approximate: bool, + algorithm: str, + default_error: str, + mergeable: bool, + result_interpretation: str, +) -> FunctionApproximationPolicy: + """Build one explicit approximation metadata record.""" + return FunctionApproximationPolicy( + is_approximate=is_approximate, + algorithm=algorithm, + default_error=default_error, + mergeable=mergeable, + result_interpretation=result_interpretation, + ) + + +pub def no_approximation_policy() -> FunctionApproximationPolicy: + """Return the approximation metadata used by exact or structural helpers.""" + return approximation_policy(false, "", "", false, "") + + +pub def fixed_approximation_policy( + algorithm: str, + default_error: str, + mergeable: bool, + result_interpretation: str, +) -> FunctionApproximationPolicy: + """Return approximation metadata for helpers with a fixed portable author contract.""" + return approximation_policy(true, algorithm, default_error, mergeable, result_interpretation) + + pub def function_policy_spec( namespace: str, policy_category: FunctionPolicyCategory, @@ -396,6 +447,7 @@ pub def function_policy_spec( null_behavior: FunctionNullBehavior, error_behavior: FunctionErrorBehavior, aggregate_modifiers: AggregateModifierPolicy, + approximation: FunctionApproximationPolicy, substrait: SubstraitMapping, ) -> FunctionSpec: """Build one function spec with explicit RFC 024 namespace and policy metadata.""" @@ -410,6 +462,7 @@ pub def function_policy_spec( null_behavior=null_behavior, error_behavior=error_behavior, aggregate_modifiers=aggregate_modifiers, + approximation=approximation, substrait=substrait, ) @@ -432,6 +485,31 @@ pub def deterministic_spec( null_behavior=null_behavior, error_behavior=FunctionErrorBehavior.Typechecked, aggregate_modifiers=_aggregate_modifier_policy_for_class(function_class), + approximation=no_approximation_policy(), + substrait=substrait, + ) + + +pub def approximate_aggregate_spec( + lifecycle: FunctionLifecycle, + null_behavior: FunctionNullBehavior, + aggregate_modifiers: AggregateModifierPolicy, + approximation: FunctionApproximationPolicy, + substrait: SubstraitMapping, +) -> FunctionSpec: + """Build one portable deterministic aggregate spec for intentionally approximate results.""" + return FunctionSpec( + namespace=core_function_namespace(), + policy_category=FunctionPolicyCategory.PortableCore, + function_class=FunctionClass.Aggregate, + aliases=[], + alias_policy=FunctionAliasPolicy.CoreImport, + lifecycle=lifecycle, + determinism=FunctionDeterminism.Deterministic, + null_behavior=null_behavior, + error_behavior=FunctionErrorBehavior.Typechecked, + aggregate_modifiers=aggregate_modifiers, + approximation=approximation, substrait=substrait, ) @@ -457,6 +535,7 @@ pub def extension_only_spec( null_behavior, error_behavior, no_aggregate_modifiers(), + no_approximation_policy(), substrait, ) @@ -483,6 +562,7 @@ pub def compatibility_alias_spec( null_behavior, error_behavior, _aggregate_modifier_policy_for_class(function_class), + no_approximation_policy(), substrait, ) @@ -508,6 +588,7 @@ pub def engine_specific_spec( null_behavior, error_behavior, no_aggregate_modifiers(), + no_approximation_policy(), substrait, ) diff --git a/src/functions/approximate/approx_count_distinct.incn b/src/functions/approximate/approx_count_distinct.incn new file mode 100644 index 0000000..e5e70b7 --- /dev/null +++ b/src/functions/approximate/approx_count_distinct.incn @@ -0,0 +1,54 @@ +""" +Approximate distinct-count aggregate helper. + +`approx_count_distinct` estimates the number of distinct non-null values produced by one expression. +""" + +from aggregate_builders import AggregateKind, AggregateMeasure, approx_count_distinct as approx_count_distinct_builder +from function_registry import ( + FunctionLifecycle, + FunctionNullBehavior, + approximate_aggregate_spec, + approximate_distinct_aggregate_modifier_policy, + extension_mapping, + fixed_approximation_policy, + v0_1, +) +from functions.registry import function_registry +from projection_builders import ColumnExpr, col, column_expr_name +from substrait.function_extensions import APPROX_COUNT_DISTINCT_FUNCTION_ANCHOR + + +@function_registry.add("approx_count_distinct", approximate_aggregate_spec( + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.NullSkippingAggregate, + approximate_distinct_aggregate_modifier_policy(), + fixed_approximation_policy( + "HyperLogLog", + "implementation-defined fixed HyperLogLog precision", + true, + "approximate non-null distinct-count cardinality estimate", + ), + extension_mapping("approx_count_distinct", APPROX_COUNT_DISTINCT_FUNCTION_ANCHOR), +)) +pub def approx_count_distinct(expr: ColumnExpr) -> AggregateMeasure: + """ + Estimate distinct non-null values produced by one expression. + + Examples: + active_users = approx_count_distinct(col("user_id")) + + Parameters: + expr: Expression whose distinct non-null values should be estimated. + """ + return approx_count_distinct_builder(expr) + + +module tests: + def test_approx_count_distinct_builds_approximate_aggregate_measure() -> None: + measure = approx_count_distinct(col("user_id")) + assert measure.kind == AggregateKind.ApproxCountDistinct + assert measure.has_expr + assert not measure.is_distinct + assert not measure.has_filter + assert column_expr_name(measure.expr) == "user_id" diff --git a/src/functions/approximate/mod.incn b/src/functions/approximate/mod.incn new file mode 100644 index 0000000..fdd821e --- /dev/null +++ b/src/functions/approximate/mod.incn @@ -0,0 +1,3 @@ +"""Approximate aggregate helpers.""" + +pub from functions.approximate.approx_count_distinct import approx_count_distinct diff --git a/src/functions/mod.incn b/src/functions/mod.incn index 6652a8f..aa07cad 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -35,6 +35,7 @@ pub from functions.aggregates.sum import sum pub from functions.aggregates.avg import avg pub from functions.aggregates.min import min pub from functions.aggregates.max import max +pub from functions.approximate.approx_count_distinct import approx_count_distinct pub from functions.math.abs import abs pub from functions.math.ceil import ceil pub from functions.math.floor import floor diff --git a/src/lib.incn b/src/lib.incn index 2b6670e..c4683a0 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -81,6 +81,7 @@ pub from functions.aggregates.sum import sum pub from functions.aggregates.avg import avg pub from functions.aggregates.min import min pub from functions.aggregates.max import max +pub from functions.approximate.approx_count_distinct import approx_count_distinct pub from functions.math.abs import abs pub from functions.math.ceil import ceil pub from functions.math.floor import floor @@ -165,6 +166,7 @@ pub from function_registry import ( FunctionLifecycle, FunctionNullBehavior, FunctionPolicyCategory, + FunctionApproximationPolicy, FunctionRegistry, FunctionRegistryEntry, FunctionSpec, @@ -172,15 +174,20 @@ pub from function_registry import ( RejectedFunctionPolicy, SubstraitMapping, SubstraitMappingKind, + approximate_aggregate_spec, + approximate_distinct_aggregate_modifier_policy, + approximation_policy, compatibility_alias_spec, core_function_namespace, deterministic_spec, engine_specific_spec, extension_only_spec, extension_mapping, + fixed_approximation_policy, function_ref_for, function_policy_spec, namespaced_function_ref, + no_approximation_policy, rejected_function_policy, relation_extension_mapping, rewrite_mapping, diff --git a/src/session/datafusion_backend.incn b/src/session/datafusion_backend.incn index eb13a7b..9087a84 100644 --- a/src/session/datafusion_backend.incn +++ b/src/session/datafusion_backend.incn @@ -10,11 +10,16 @@ from rust::datafusion::prelude import CsvReadOptions, ParquetReadOptions from rust::datafusion::dataframe import DataFrameWriteOptions from rust::datafusion_substrait::substrait::proto import Plan as ConsumerPlan from rust::datafusion_substrait::logical_plan::consumer import from_substrait_plan +from rust::substrait::proto::extensions import SimpleExtensionDeclaration +from rust::substrait::proto::extensions::simple_extension_declaration import ExtensionFunction, MappingType from backends import SourceKind, TableSource from dataset.materialization import DataFrameMaterialization from session.backend_types import BackendError, BackendErrorKind, BackendRegistration, backend_error from substrait.inspect import root_names +const SUBSTRAIT_APPROX_COUNT_DISTINCT_FUNCTION_NAME: str = "approx_count_distinct" +const DATAFUSION_APPROX_COUNT_DISTINCT_FUNCTION_NAME: str = "approx_distinct" + @derive(Clone) enum DataFusionSourceRegistration(str): @@ -165,12 +170,56 @@ def datafusion_registration_for_source(source: TableSource) -> DataFusionSourceR def _consumer_plan_from_current_plan(plan: Plan) -> Result[ConsumerPlan, BackendError]: """Decode the producer-side plan bytes into DataFusion's consumer Plan type.""" - encoded = plan.encode_to_vec() + encoded = _datafusion_producer_plan(plan).encode_to_vec() match ConsumerPlan.decode(encoded.as_slice()): Ok(decoded) => return Ok(decoded) Err(err) => return Err(backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) +def _datafusion_producer_plan(plan: Plan) -> Plan: + """Return a producer plan with extension names adapted for the DataFusion consumer only.""" + # Rebuild the producer-side plan before encoding so InQL's emitted plan keeps standard Substrait names. + # The DataFusion consumer re-export currently does not expose extension declarations to Incan metadata. + return Plan( + version=plan.version, + extension_urns=plan.extension_urns, + extensions=[_datafusion_extension_declaration(extension) for extension in plan.extensions], + relations=plan.relations, + advanced_extensions=plan.advanced_extensions, + expected_type_urls=plan.expected_type_urls, + parameter_bindings=plan.parameter_bindings, + type_aliases=plan.type_aliases, + ) + + +def _datafusion_extension_declaration(extension: SimpleExtensionDeclaration) -> SimpleExtensionDeclaration: + """Return one extension declaration adapted to DataFusion implementation names when required.""" + match extension.mapping_type: + Some(MappingType.ExtensionFunction(function)) => + if function.name == SUBSTRAIT_APPROX_COUNT_DISTINCT_FUNCTION_NAME: + return _renamed_datafusion_extension_function(function, DATAFUSION_APPROX_COUNT_DISTINCT_FUNCTION_NAME) + return SimpleExtensionDeclaration(mapping_type=Some(MappingType.ExtensionFunction(function))) + _ => return extension + + +def _renamed_datafusion_extension_function( + function: ExtensionFunction, + datafusion_name: str, +) -> SimpleExtensionDeclaration: + """Return one extension-function declaration preserving anchor and URN while changing the consumer name.""" + return SimpleExtensionDeclaration( + mapping_type=Some( + MappingType.ExtensionFunction( + ExtensionFunction( + extension_urn_reference=function.extension_urn_reference, + function_anchor=function.function_anchor, + name=datafusion_name, + ), + ), + ), + ) + + def _rust_usize_to_int(value: RustUsize) -> Result[int, BackendError]: """Convert one Rust `usize` count into Incan `int` through a checked `i64` boundary.""" match RustI64.try_from(value): diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn index fa9cfed..a7804d6 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -84,6 +84,7 @@ pub const SHA224_FUNCTION_ANCHOR: u32 = 56 pub const SHA256_FUNCTION_ANCHOR: u32 = 57 pub const SHA384_FUNCTION_ANCHOR: u32 = 58 pub const SHA512_FUNCTION_ANCHOR: u32 = 59 +pub const APPROX_COUNT_DISTINCT_FUNCTION_ANCHOR: u32 = 60 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" const EXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode_outer" diff --git a/tests/test_approximate_functions.incn b/tests/test_approximate_functions.incn new file mode 100644 index 0000000..c2a4863 --- /dev/null +++ b/tests/test_approximate_functions.incn @@ -0,0 +1,80 @@ +"""Test: RFC 023 approximate aggregate helper surface.""" + +from std.testing import assert_is_err, fail_t +from functions import approx_count_distinct, asc, col, eq, str_lit +from aggregate_builders import AggregateKind +from function_registry import function_ref_for +from projection_builders import column_expr_name +from substrait.errors import SubstraitLoweringErrorKind +from substrait.relations import read_named_table_rel, try_aggregate_rel_of_columns +from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind +from substrait.schema_registry import register_named_table_schema + + +def _register_approx_orders_schema() -> None: + """Register a small table schema for aggregate lowering tests.""" + register_named_table_schema( + "approx_orders", + [RowColumnSpec(name="customer_id", kind=SubstraitPrimitiveKind.String, nullable=false), RowColumnSpec( + name="product_id", + kind=SubstraitPrimitiveKind.String, + nullable=true, + )], + ) + + +def test_approximate_functions__approx_count_distinct_builds_aggregate_measure() -> None: + # -- Arrange -- + product_id = col("product_id") + + # -- Act -- + measure = approx_count_distinct(product_id) + + # -- Assert -- + assert measure.kind == AggregateKind.ApproxCountDistinct, "approx_count_distinct should use its own aggregate kind" + assert measure.function_ref == function_ref_for("approx_count_distinct"), "measure should preserve registry identity" + assert measure.canonical_name == "approx_count_distinct", "measure should preserve canonical helper name" + assert measure.has_expr, "approx_count_distinct should require an input expression" + assert not measure.is_distinct, "approx_count_distinct should not add a distinct modifier" + assert not measure.has_filter, "approx_count_distinct should start without a filter" + assert len(measure.ordering) == 0, "approx_count_distinct should start without ordered input" + assert column_expr_name(measure.expr) == "product_id", "measure should preserve the input expression" + + +def test_approximate_functions__approx_count_distinct_allows_filter_but_not_distinct_or_ordered_input() -> None: + # -- Arrange -- + _register_approx_orders_schema() + base = read_named_table_rel("approx_orders") + product_id = col("product_id") + paid = eq(col("customer_id"), str_lit("A")) + + # -- Act -- + filtered_result = try_aggregate_rel_of_columns( + base.clone(), + ["customer_id", "product_id"], + [col("customer_id")], + [approx_count_distinct(product_id.clone()).filter(paid)], + ) + distinct_result = try_aggregate_rel_of_columns( + base.clone(), + ["customer_id", "product_id"], + [col("customer_id")], + [approx_count_distinct(product_id.clone()).distinct()], + ) + ordered_result = try_aggregate_rel_of_columns( + base, + ["customer_id", "product_id"], + [col("customer_id")], + [approx_count_distinct(product_id).order_by([asc(col("product_id"))])], + ) + + # -- Assert -- + match filtered_result: + Ok(_) => pass + Err(err) => return fail_t(err.error_message()) + distinct_err = assert_is_err(distinct_result, "approx_count_distinct should reject extra DISTINCT") + ordered_err = assert_is_err(ordered_result, "approx_count_distinct should reject ordered input") + assert distinct_err.kind == SubstraitLoweringErrorKind.InvalidScalarExpression, "distinct rejection should be a lowering diagnostic" + assert ordered_err.kind == SubstraitLoweringErrorKind.InvalidScalarExpression, "ordered-input rejection should be a lowering diagnostic" + assert distinct_err.message.contains("DISTINCT"), "distinct rejection should identify the modifier" + assert ordered_err.message.contains("ordered aggregate input"), "ordered rejection should identify the modifier" diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 197c275..8b68e24 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -8,6 +8,7 @@ from functions import ( always_false, always_true, and_, + approx_count_distinct, asc, asc_nulls_first, asc_nulls_last, @@ -129,6 +130,7 @@ from substrait.function_extensions import ( ABS_FUNCTION_ANCHOR, ADD_FUNCTION_ANCHOR, AND_FUNCTION_ANCHOR, + APPROX_COUNT_DISTINCT_FUNCTION_ANCHOR, ARRAY_DISTINCT_FUNCTION_ANCHOR, ARRAY_ELEMENT_FUNCTION_ANCHOR, ARRAY_EXCEPT_FUNCTION_ANCHOR, @@ -249,12 +251,12 @@ def _local_entry_by_namespace_and_name_or_fail( def _expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer", "window", "row_number", "rank", "dense_rank", "sha224", "sha256", "sha384", "sha512", "sha2", "md5"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "approx_count_distinct", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer", "window", "row_number", "rank", "dense_rank", "sha224", "sha256", "sha384", "sha512", "sha2", "md5"] def _expected_substrait_mapped_names() -> list[str]: """Return helpers with concrete Substrait extension-function mappings.""" - return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "row_number", "rank", "dense_rank", "sha224", "sha256", "sha384", "sha512", "md5"] + return ["sum", "count", "count_expr", "avg", "min", "max", "approx_count_distinct", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "row_number", "rank", "dense_rank", "sha224", "sha256", "sha384", "sha512", "md5"] def _exercise_current_public_helpers() -> None: @@ -282,6 +284,7 @@ def _exercise_current_public_helpers() -> None: avg(amount) min(amount) max(amount) + approx_count_distinct(status) add(amount, lit(1)) mul(amount, int_lit(2)) eq(status, str_lit("paid")) @@ -490,6 +493,7 @@ def test_function_registry__aggregate_helpers_expose_modifier_policy() -> None: # -- Act -- count_entry = _entry_by_name_or_fail("count") sum_entry = _entry_by_name_or_fail("sum") + approx_entry = _entry_by_name_or_fail("approx_count_distinct") abs_entry = _entry_by_name_or_fail("abs") # -- Assert -- @@ -498,10 +502,30 @@ def test_function_registry__aggregate_helpers_expose_modifier_policy() -> None: assert not count_entry.aggregate_modifiers.allows_ordered_input, "core aggregates should reject ordered input until an order-sensitive aggregate lands" assert sum_entry.aggregate_modifiers.allows_distinct, "numeric aggregates should allow DISTINCT" assert sum_entry.aggregate_modifiers.allows_filter, "numeric aggregates should allow aggregate-local filters" + assert not approx_entry.aggregate_modifiers.allows_distinct, "approximate distinct-count semantics should reject an extra DISTINCT modifier" + assert approx_entry.aggregate_modifiers.allows_filter, "approximate distinct-count should allow aggregate-local filters" + assert not approx_entry.aggregate_modifiers.allows_ordered_input, "approximate distinct-count should reject ordered input" assert not abs_entry.aggregate_modifiers.allows_distinct, "scalar helpers should not expose aggregate modifier support" assert not abs_entry.aggregate_modifiers.allows_filter, "scalar helpers should not expose aggregate modifier support" +def test_function_registry__approximate_helpers_expose_approximation_policy() -> None: + """Assert approximate helpers are explicit registry metadata, not hidden backend choices.""" + # -- Arrange -- + _exercise_current_public_helpers() + + # -- Act -- + approx_entry = _entry_by_name_or_fail("approx_count_distinct") + count_entry = _entry_by_name_or_fail("count") + + # -- Assert -- + assert approx_entry.approximation.is_approximate, "approx_count_distinct should be explicitly approximate" + assert approx_entry.approximation.algorithm == "HyperLogLog", "approx_count_distinct should expose its sketch family" + assert approx_entry.approximation.mergeable, "approx_count_distinct should be marked mergeable" + assert approx_entry.approximation.result_interpretation.contains("cardinality"), "approx_count_distinct should describe its result" + assert not count_entry.approximation.is_approximate, "exact count should not inherit approximation metadata" + + def test_function_registry__extension_policy_is_separate_from_scalar_class() -> None: """Assert extension-only functions can be scalar while remaining explicitly namespaced.""" # -- Arrange -- @@ -608,6 +632,7 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("avg", "avg", AVG_FUNCTION_ANCHOR) _assert_extension_mapping("min", "min", MIN_FUNCTION_ANCHOR) _assert_extension_mapping("max", "max", MAX_FUNCTION_ANCHOR) + _assert_extension_mapping("approx_count_distinct", "approx_count_distinct", APPROX_COUNT_DISTINCT_FUNCTION_ANCHOR) _assert_extension_mapping("add", "add", ADD_FUNCTION_ANCHOR) _assert_extension_mapping("mul", "multiply", MULTIPLY_FUNCTION_ANCHOR) _assert_extension_mapping("eq", "equal", EQUAL_FUNCTION_ANCHOR) @@ -762,6 +787,7 @@ def test_function_registry__public_helpers_preserve_existing_behavior() -> None: avg_measure = avg(amount) min_measure = min(amount) max_measure = max(amount) + approx_count_distinct_measure = approx_count_distinct(status) add_expr = add(amount, lit(7)) eq_expr = eq(status, str_lit("paid")) gt_expr = gt(amount, int_lit(10)) @@ -792,6 +818,7 @@ def test_function_registry__public_helpers_preserve_existing_behavior() -> None: assert avg_measure.kind == AggregateKind.Avg, "avg wrapper should preserve aggregate kind" assert min_measure.kind == AggregateKind.Min, "min wrapper should preserve aggregate kind" assert max_measure.kind == AggregateKind.Max, "max wrapper should preserve aggregate kind" + assert approx_count_distinct_measure.kind == AggregateKind.ApproxCountDistinct, "approx_count_distinct wrapper should preserve aggregate kind" assert column_expr_kind(add_expr) == ColumnExprKind.ScalarFunction, "add should use the shared scalar function kind" assert column_expr_kind(eq_expr) == ColumnExprKind.ScalarFunction, "eq should use the shared scalar function kind" assert column_expr_kind(gt_expr) == ColumnExprKind.ScalarFunction, "gt should use the shared scalar function kind" diff --git a/tests/test_prism.incn b/tests/test_prism.incn index 16b8564..3b0bf6d 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -1,6 +1,19 @@ """Internal Prism engine tests against the shared-store cursor substrate.""" -from functions import always_false, always_true, col, count, count_expr, explode, lit, mul, row_number, sum, window +from functions import ( + approx_count_distinct, + always_false, + always_true, + col, + count, + count_expr, + explode, + lit, + mul, + row_number, + sum, + window, +) from prism import ( PrismCursor, prism_cursor_apply_filter, @@ -323,6 +336,22 @@ def test_prism__rewrite_collapses_adjacent_aggregates_by_merging_measures() -> N assert relation_kind_name(root_rel(plan)) == str("AggregateRel"), "collapsed aggregate rewrite should still lower through AggregateRel" +def test_prism__aggregate_output_columns_include_approximate_measures() -> None: + # -- Arrange -- + _register_projection_test_schema(str("orders")) + aggregated: PrismCursor[Order] = prism_cursor_named_table(str("orders")).group_by([col("id")]).agg( + [approx_count_distinct(col("id"))], + ) + + # -- Act -- + output_cols = prism_cursor_output_columns(aggregated) + plan = aggregated.to_substrait_plan() + + # -- Assert -- + assert output_cols == ["id", "approx_count_distinct_id"], "approximate aggregate output columns should stay stable through Prism" + assert relation_kind_name(root_rel(plan)) == str("AggregateRel"), "approximate aggregate should lower through AggregateRel" + + def test_prism__with_column_tracks_output_columns_and_collapses_adjacent_projects() -> None: # -- Arrange -- _register_projection_test_schema("orders_projection_prism") diff --git a/tests/test_session_aggregates.incn b/tests/test_session_aggregates.incn index 871bf6b..4bc14ac 100644 --- a/tests/test_session_aggregates.incn +++ b/tests/test_session_aggregates.incn @@ -1,6 +1,19 @@ """End-to-end Session aggregate execution tests over the DataFusion backend.""" -from functions import avg, col, count, count_distinct, count_expr, count_if, eq, max, min, str_lit, sum +from functions import ( + approx_count_distinct, + avg, + col, + count, + count_distinct, + count_expr, + count_if, + eq, + max, + min, + str_lit, + sum, +) from dataset import DataFrame, LazyFrame from session import Session from std.testing import assert_is_ok, fail_t @@ -127,6 +140,30 @@ def test_session_aggregates__grouped_collect_executes_distinct_and_filter_modifi assert not payload.contains("3"), "distinct counts should not count duplicated product ids" +def test_session_aggregates__grouped_collect_executes_approx_count_distinct() -> None: + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateModifierOrder] = assert_is_ok( + session.read_csv("aggregate_modifiers", AGGREGATE_MODIFIERS_CSV_FIXTURE), + "aggregate modifiers fixture should load", + ) + paid = eq(col("status"), str_lit("paid")) + grouped = lazy.group_by([col("customer_id")]).agg([approx_count_distinct(col("product_id")).filter(paid)]) + df = _collect_modifier_or_fail(session, grouped) + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 2, "approximate grouped aggregate should produce one row per customer" + assert resolved == ["customer_id", "approx_count_distinct_product_id_filtered"], "approximate aggregate output columns should be stable" + assert payload.contains("A"), "approximate aggregate output should contain customer A" + assert payload.contains("B"), "approximate aggregate output should contain customer B" + assert payload.contains("1"), "customer A paid product estimate should materialize" + assert payload.contains("2"), "customer B paid product estimate should materialize" + + def test_session_aggregates__global_collect_executes_count() -> None: # -- Arrange -- mut session = Session.default() diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 0ec722b..bd7ffcc 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -6,6 +6,7 @@ from functions import ( add, always_true, and_, + approx_count_distinct, array, array_contains, array_distinct, @@ -566,6 +567,26 @@ def test_plan__aggregate_rel_lowers_distinct_and_filter_modifiers() -> None: assert aggregate_sort_counts == [0, 0, 0], "core aggregates should not emit ordered input fields" +def test_plan__aggregate_rel_lowers_approx_count_distinct() -> None: + # -- Arrange -- + _register_orders_schema() + base = read_named_table_rel("orders") + aggregated = aggregate_rel(base, [col("id")], [approx_count_distinct(col("id"))]) + plan = plan_from_root_relation(aggregated, ["id", "approx_count_distinct_id"]) + + # -- Act -- + output_columns = relation_output_columns(aggregated) + aggregate_functions = aggregate_measure_function_names(aggregated) + aggregate_invocations = aggregate_measure_invocation_names(aggregated) + + # -- Assert -- + assert relation_kind_name(aggregated) == "AggregateRel", "approximate aggregate lowering should emit AggregateRel" + assert plan_has_extension_urn(plan, registered_substrait_extension_uris()[0]), "approximate aggregate plans should register aggregate extensions" + assert output_columns == ["id", "approx_count_distinct_id"], "approximate aggregate output names should remain stable" + assert aggregate_functions == ["approx_count_distinct"], "approx_count_distinct should keep its standard Substrait function name" + assert aggregate_invocations == ["All"], "approx_count_distinct should not lower as an exact DISTINCT invocation" + + def test_plan__aggregate_rel_rejects_invalid_modifier_shapes() -> None: # -- Arrange -- _register_orders_schema() @@ -573,6 +594,12 @@ def test_plan__aggregate_rel_rejects_invalid_modifier_shapes() -> None: # -- Act -- distinct_count_result = try_aggregate_rel_of_columns(base.clone(), ["id"], [col("id")], [count().distinct()]) + distinct_approx_result = try_aggregate_rel_of_columns( + base.clone(), + ["id"], + [col("id")], + [approx_count_distinct(col("id")).distinct()], + ) ordered_sum_result = try_aggregate_rel_of_columns( base, ["id"], @@ -580,11 +607,14 @@ def test_plan__aggregate_rel_rejects_invalid_modifier_shapes() -> None: [sum(col("id")).order_by([asc(col("id"))])], ) distinct_err = assert_is_err(distinct_count_result, "count().distinct() should require an input expression") + distinct_approx_err = assert_is_err(distinct_approx_result, "approx_count_distinct should reject extra DISTINCT") ordered_err = assert_is_err(ordered_sum_result, "core aggregates should reject ordered input for now") # -- Assert -- assert distinct_err.kind == SubstraitLoweringErrorKind.InvalidScalarExpression, "invalid distinct count should be a scalar lowering diagnostic" assert distinct_err.message.contains("DISTINCT"), "invalid distinct diagnostic should mention DISTINCT" + assert distinct_approx_err.kind == SubstraitLoweringErrorKind.InvalidScalarExpression, "invalid approximate distinct should be a scalar lowering diagnostic" + assert distinct_approx_err.message.contains("DISTINCT"), "approximate distinct diagnostic should mention DISTINCT" assert ordered_err.kind == SubstraitLoweringErrorKind.InvalidScalarExpression, "unsupported ordered aggregate should be a scalar lowering diagnostic" assert ordered_err.message.contains("ordered aggregate input"), "ordered aggregate diagnostic should identify the unsupported modifier"