diff --git a/docs/language/reference/dataset_methods.md b/docs/language/reference/dataset_methods.md index ab4926e..9a9a701 100644 --- a/docs/language/reference/dataset_methods.md +++ b/docs/language/reference/dataset_methods.md @@ -19,9 +19,10 @@ The Substrait helper surface behind these methods is split by semantic role: | `with_column` | `def with_column(self, name: str, expr: ColumnExpr) -> Self` | Add or replace one projected column using a scalar expression. | | `group_by` | `def group_by(self, columns: list[ColumnExpr]) -> Self` | Define grouping keys using scalar expressions. | | `agg` | `def agg(self, measures: list[AggregateMeasure]) -> Self` | Apply aggregate measures over the current relation or current grouping. | +| `generate` | `def generate(self, generator: GeneratorApplication) -> Self` | Apply a relation-shaping generator such as `explode(...)` with explicit output aliases. | | `order_by` | `def order_by(self, columns: list[ColumnExpr]) -> Self` | Sort rows by scalar expressions or ordering helpers such as `asc(...)` and `desc(...)`. | | `limit` | `def limit(self, n: int) -> Self` | Cap row count. | -| `explode` | `def explode(self) -> Self` | Expand a nested list column into rows. | +| `explode` | `def explode(self) -> Self` | Compatibility marker for the older EXPLODE extension path. Prefer `generate(explode(...))`. | ## `with_column` @@ -67,6 +68,7 @@ def enrich(orders: LazyFrame[Order]) -> LazyFrame[Order]: - `join(...)` is constrained to same-carrier inputs and the boolean join predicate surface shown in the signature. - `select(...)` preserves projection shape; explicit projection lists are represented today through `with_column(...)` and scalar-expression builders. +- `generate(...)` preserves all input columns and appends generated output aliases. Alias collisions are rejected during planning/lowering. - `DataFrame[T]` exposes materialized metadata and preview text; row-level accessors belong to the materialized DataFrame API surface. - Query-block and scoped DSL surfaces lower into these builder APIs rather than defining separate method semantics. diff --git a/docs/language/reference/functions/generators.md b/docs/language/reference/functions/generators.md new file mode 100644 index 0000000..844cb20 --- /dev/null +++ b/docs/language/reference/functions/generators.md @@ -0,0 +1,32 @@ +# Generator and Table-Valued Functions (Reference) + +Generators are relation-shaping operations. They are registry-backed like scalar and aggregate helpers, but they return +`GeneratorApplication` values and must be applied through a relation method such as `generate(...)`. + +```incan +from pub::inql import LazyFrame +from pub::inql.functions import col, explode +from models import Order + +def order_lines(orders: LazyFrame[Order]) -> LazyFrame[Order]: + return orders.generate(explode(col("line_items"), "line_item")) +``` + +The explicit generator surface currently includes: + +| Function | Output aliases | Relation effect | +| --- | --- | --- | +| `explode(expr, as_)` | one value column | Emits one row per array element; null or empty inputs emit zero rows. | +| `explode_outer(expr, as_)` | one value column | Preserves the input row for null or empty inputs and emits a null generated value. | +| `posexplode(expr, position_as, value_as)` | position and value columns | Emits one row per array element with a zero-based position column. | +| `posexplode_outer(expr, position_as, value_as)` | position and value columns | Outer positional explode with the same zero-based position rule. | + +Generator applications preserve input columns and append generated columns in declaration order. Generated aliases are +required, must be non-empty, and must not collide with existing input columns. + +The older zero-argument `DataSet.explode()` method remains available as a compatibility marker for the current Substrait +extension relation gap. New code should prefer `generate(explode(...))` so the relation-shaping function identity and +output schema are explicit. + +Nested scalar helpers such as `array_flatten(...)` remain scalar expressions. They do not expand rows and are documented +on the [nested data functions](nested.md) page. diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index f6347a8..e65ea90 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -7,11 +7,12 @@ Today the concrete shipped surfaces are documented here: - [Filter builders](../builders/filters.md) - [Aggregate builders](../builders/aggregates.md) - [Projection builders](../builders/projections.md) +- [Generator and table-valued functions](generators.md) - [Nested data functions](nested.md) The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. -The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. +The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature. The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows. @@ -33,6 +34,7 @@ The registered helper surface currently includes: | `in_(...)`, `between(...)` | scalar | built-in membership/range lowering (`SingularOrList` and `between`) | | `abs(...)`, `ceil(...)`, `floor(...)`, `round(...)` | scalar | registered Substrait math scalar mappings; `round(...)` is currently the single-argument form | | `array(...)`, `cardinality(...)`, `array_contains(...)`, `arrays_overlap(...)`, `array_position(...)`, `element_at(...)`, `array_sort(...)`, `array_distinct(...)`, `array_except(...)`, `array_intersect(...)`, `array_union(...)`, `array_join(...)`, `array_slice(...)`, `array_reverse(...)`, `array_flatten(...)`, `map_from_arrays(...)`, `map_extract(...)`, `map_contains_key(...)`, `map_keys(...)`, `map_values(...)`, `map_entries(...)`, `named_struct(...)` | scalar | registered nested scalar helpers backed by Substrait extension mappings; `map_contains_key(...)` lowers as a documented predicate rewrite | +| `explode(...)`, `explode_outer(...)`, `posexplode(...)`, `posexplode_outer(...)` | generator | relation-extension mappings consumed by `generate(...)`; positional forms use zero-based positions | | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` | | `sum(...)`, `count()`, `count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` helper overloading | | `count_distinct(...)`, `count_if(...)` | aggregate | compatibility helpers that lower through aggregate modifiers over canonical `count` semantics | diff --git a/docs/language/reference/substrait/operator_catalog.md b/docs/language/reference/substrait/operator_catalog.md index 4560185..327ad49 100644 --- a/docs/language/reference/substrait/operator_catalog.md +++ b/docs/language/reference/substrait/operator_catalog.md @@ -81,6 +81,9 @@ Core Substrait does not define a portable unnest or explode `Rel` at the logical Current package-level RFC 002 boundary registration: - `https://inql.io/extensions/v0.1/unnest.yaml#explode` +- `https://inql.io/extensions/v0.1/unnest.yaml#explode_outer` +- `https://inql.io/extensions/v0.1/unnest.yaml#posexplode` +- `https://inql.io/extensions/v0.1/unnest.yaml#posexplode_outer` ### Pivot / unpivot diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index 2543685..5c23085 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -16,6 +16,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Core scalar functions:** RFC 015 adds registry-backed scalar function applications and the first core helper slice for casts, comparisons, boolean logic, null/NaN predicates, arithmetic, conditionals, membership/range predicates, and ordering expressions. Implemented helpers lower to Substrait IR through registry metadata, built-in Rex shapes, or structural sort-field lowering; DataFusion remains the first execution adapter rather than the semantic boundary. - **Common scalar functions:** The first RFC 018 slice adds registry-backed math helpers for `abs(...)`, `ceil(...)`, `floor(...)`, and single-argument `round(...)`, with Substrait mappings and DataFusion-backed execution coverage. - **Nested data functions:** RFC 020 adds registry-backed scalar helpers for array construction/access, cardinality, containment, overlap, sorting, set-like operations, joining, slicing, reversing, scalar array flattening, map construction/access, map key/value/entry extraction, map key containment, and named struct construction. These helpers lower through Substrait extension metadata and execute through the DataFusion-backed Session path without introducing generator semantics. +- **Generator functions:** RFC 021 adds registry-backed generator applications for `explode(...)`, `explode_outer(...)`, `posexplode(...)`, and `posexplode_outer(...)`. Generators remain relation-shaping operations applied with `generate(...)`; they preserve input columns, require explicit output aliases, and lower through the current Substrait extension-relation gap encoding. - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation. - **Function extension policy:** RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. diff --git a/docs/rfcs/021_generator_table_functions.md b/docs/rfcs/021_generator_table_functions.md index b33febb..ad0039c 100644 --- a/docs/rfcs/021_generator_table_functions.md +++ b/docs/rfcs/021_generator_table_functions.md @@ -1,6 +1,6 @@ # InQL RFC 021: Generator and table-valued functions -- **Status:** Draft +- **Status:** In Progress - **Created:** 2026-04-27 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -42,14 +42,15 @@ InQL already has an unnest/explode design direction through its Substrait work. ## Guide-level explanation (how authors think about it) -Authors should use generators when one input row may become multiple output rows: +Authors should use generators when one input row may become multiple output rows. In the current builder surface, +generators are constructed as explicit applications and then applied to a relation: ```incan -from pub::inql.functions import col +from pub::inql.functions import col, explode items = ( orders - .explode(col("line_items"), as_="line_item") + .generate(explode(col("line_items"), "line_item")) .select(["order_id", "line_item"]) ) ``` @@ -64,13 +65,13 @@ Generator functions must be registry entries with function class `generator` or `explode_outer(array_expr)` must preserve the input row when the input array is null or empty and must produce a null generated value according to its output schema. -`posexplode(array_expr)` and `posexplode_outer(array_expr)` must include a positional output column in addition to the generated element. The position origin must be specified before this RFC reaches Planned status. +`posexplode(array_expr)` and `posexplode_outer(array_expr)` must include a positional output column in addition to the generated element. Positional output is zero-based because `posexplode` follows the Spark-compatible naming convention rather than InQL's one-based scalar collection indexing rule. `inline(array_of_struct_expr)` must expand each struct element into output columns. `inline_outer` must preserve outer rows for null or empty input according to the outer generator rule. `stack` must construct multiple output rows from explicit expressions according to a declared row count and output schema. -`flatten` must be treated as a table-valued/generator operation when supported. Its exact input type, recursive behavior, path behavior, and output columns must be specified before it reaches Planned status. +`flatten` must be treated as a table-valued/generator operation when supported. Portable InQL does not yet define Snowflake-style recursive/path flattening; scalar `array_flatten(...)` remains part of RFC 020 and does not change row cardinality. Every generator must define output column names, output types, nullability, interaction with existing columns, and aliasing requirements. Name collisions must be diagnosed unless an explicit overwrite or qualification rule applies. @@ -78,11 +79,11 @@ Every generator must define output column names, output types, nullability, inte ### Syntax -Generators may appear as dataframe relation methods, query-block clauses, or table-valued function forms. Regardless of syntax, they must lower to relation-shaping operations. +Generators may appear as dataframe relation methods, query-block clauses, or table-valued function forms. Regardless of syntax, they must lower to relation-shaping operations. The initial builder API uses `generate(generator)` to avoid overloading the existing zero-argument compatibility `explode()` method. ### Semantics -Generator output schema is part of the relation schema after the generator operation. Generators may preserve input columns, replace a nested column with generated columns, or produce a new relation depending on the function and syntax, but the behavior must be explicit. +Generator output schema is part of the relation schema after the generator operation. The initial portable generator applications preserve all input columns and append generated output columns in declaration order. Generated aliases are required, must be non-empty, and must not collide with existing columns. ### Interaction with other InQL surfaces @@ -112,11 +113,16 @@ Existing unnest/explode behavior should align with this RFC. If current behavior - **Execution / interchange** — Prism and Substrait lowering must represent cardinality changes and output schemas faithfully. - **Documentation** — generator docs should explain cardinality and schema effects before listing helper names. -## Unresolved questions +## Design Decisions -- Should positional generators use zero-based or one-based positions? -- Should `.explode(...)` preserve all input columns by default? -- What aliasing syntax should be required for generated output columns? -- What subset of Snowflake-style `flatten` behavior belongs in portable InQL versus a warehouse compatibility extension? +### Resolved - +- Positional generators use zero-based positions for compatibility with the `posexplode` naming convention. +- Explicit generator applications preserve all input columns by default and append generated output columns. +- Generated aliases are required at builder construction time. +- Snowflake-style recursive/path `flatten` remains outside the portable core until its output schema and compatibility category are specified separately. + +### Remaining + +- `inline`, `inline_outer`, `stack`, and portable table-valued `flatten` need separate helper slices on top of the generator application model. +- Query-block generator syntax still needs compiler/query-surface work. diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md index fac71de..c42c434 100644 --- a/docs/rfcs/README.md +++ b/docs/rfcs/README.md @@ -27,7 +27,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la | [018][rfc-018] | Draft | Common scalar function catalog | | | [019][rfc-019] | Draft | Window functions | | | [020][rfc-020] | Draft | Nested data functions | | -| [021][rfc-021] | Draft | Generator and table-valued functions | | +| [021][rfc-021] | In Progress | Generator and table-valued functions | | | [022][rfc-022] | Draft | Semi-structured and format functions | | | [023][rfc-023] | Draft | Approximate and sketch functions | | | [024][rfc-024] | Draft | Function extension policy | | diff --git a/src/dataset/mod.incn b/src/dataset/mod.incn index fa850bd..e9b31b1 100644 --- a/src/dataset/mod.incn +++ b/src/dataset/mod.incn @@ -22,6 +22,7 @@ The current method-chain surface in this module is the explicit builder-based AP - `with_column(name: str, expr: ColumnExpr)` - `group_by(columns: list[ColumnExpr])` - `agg(measures: list[AggregateMeasure])` +- `generate(generator: GeneratorApplication)` - plus the structural operators `join`, `select`, `order_by`, `limit`, and `explode` Illustrative current-shape examples: @@ -53,6 +54,7 @@ See also: from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr from dataset.materialization import DataFrameMaterialization from substrait.errors import SubstraitLoweringError @@ -63,6 +65,7 @@ from dataset.ops import ( agg_ds_of_columns, explode_ds, filter_ds_of_columns, + generate_ds_of_columns, group_by_ds_of_columns, join_ds, limit_ds, @@ -76,6 +79,7 @@ from prism import ( prism_cursor_apply_agg, prism_cursor_apply_explode, prism_cursor_apply_filter, + prism_cursor_apply_generate, prism_cursor_apply_group_by, prism_cursor_apply_join, prism_cursor_apply_limit, @@ -98,6 +102,7 @@ pub trait DataSet[T with Clone]: def with_column(self, name: str, expr: ColumnExpr) -> Self def group_by(self, columns: list[ColumnExpr]) -> Self def agg(self, measures: list[AggregateMeasure]) -> Self + def generate(self, generator: GeneratorApplication) -> Self def order_by(self, columns: list[ColumnExpr]) -> Self def limit(self, n: int) -> Self def explode(self) -> Self @@ -207,6 +212,12 @@ pub class DataFrame[T with Clone] with BoundedDataSet: agg_ds_of_columns(self._substrait_rel, self.planned_columns(), measures), ) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new DataFrame with a generator stage and stale materialization cleared.""" + return _data_frame_with_invalidated_materialization( + generate_ds_of_columns(self._substrait_rel, self.planned_columns(), generator), + ) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataFrame with an ordering stage and stale materialization cleared.""" return _data_frame_with_invalidated_materialization( @@ -288,6 +299,10 @@ pub class LazyFrame[T with Clone] with BoundedDataSet: """Return one new lazy carrier with an appended aggregation stage.""" return LazyFrame(_cursor=prism_cursor_apply_agg(self._cursor, measures)) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new lazy carrier with an appended generator stage.""" + return LazyFrame(_cursor=prism_cursor_apply_generate(self._cursor, generator)) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new lazy carrier with an appended ordering stage.""" return LazyFrame(_cursor=prism_cursor_apply_order_by(self._cursor, columns)) @@ -430,6 +445,17 @@ pub class DataStream[T with Clone] with UnboundedDataSet: ), ) + def generate(self, generator: GeneratorApplication) -> Self: + """Return one new DataStream with a generator stage.""" + return DataStream( + _row_schema_marker=self._row_schema_marker.clone(), + _substrait_rel=generate_ds_of_columns( + self._substrait_rel, + relation_output_columns(self._substrait_rel.clone()), + generator, + ), + ) + def order_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataStream with an ordering stage.""" return DataStream( diff --git a/src/dataset/ops.incn b/src/dataset/ops.incn index 5319f4d..bafad30 100644 --- a/src/dataset/ops.incn +++ b/src/dataset/ops.incn @@ -8,6 +8,7 @@ views stay aligned with the lowered relation tree. from rust::substrait::proto import Rel from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment from substrait.function_extensions import explode_extension_uri from substrait.inspect import relation_output_columns @@ -19,6 +20,7 @@ from substrait.relations import ( join_rel, project_rel_of_columns, sort_rel_of_columns, + generator_rel_of_columns, ) @@ -122,6 +124,16 @@ pub def agg_ds_of_columns(rel: Rel, input_columns: list[str], measures: list[Agg return aggregate_rel_of_columns(rel, input_columns, [], measures) +pub def generate_ds(rel: Rel, generator: GeneratorApplication) -> Rel: + """Apply one relation-shaping generator to a relation.""" + return generate_ds_of_columns(rel, relation_output_columns(rel.clone()), generator) + + +pub def generate_ds_of_columns(rel: Rel, input_columns: list[str], generator: GeneratorApplication) -> Rel: + """Apply one relation-shaping generator using explicit input-column names.""" + return generator_rel_of_columns(rel, input_columns, generator) + + pub def order_by_ds(rel: Rel, columns: list[ColumnExpr]) -> Rel: """ Apply dataset-level ordering intent to one relation. diff --git a/src/function_registry.incn b/src/function_registry.incn index b5642f9..2ac97ff 100644 --- a/src/function_registry.incn +++ b/src/function_registry.incn @@ -75,6 +75,7 @@ pub enum SubstraitMappingKind(str): CoreFunction = "core_function" ExtensionFunction = "extension_function" + RelationExtension = "relation_extension" Rewrite = "rewrite" StructuralFunction = "structural_function" @@ -294,6 +295,18 @@ pub def extension_mapping(function_name: str, anchor: u32) -> SubstraitMapping: ) +pub def relation_extension_mapping(function_name: str, uri: str) -> SubstraitMapping: + """Build one registered Substrait relation-extension mapping.""" + return SubstraitMapping( + kind=SubstraitMappingKind.RelationExtension, + uri=uri, + function_name=function_name, + anchor=0, + rewrite="", + detail="extension_single", + ) + + pub def core_mapping(function_name: str) -> SubstraitMapping: """Build one mapping for a built-in Substrait Rex shape rather than an extension function declaration.""" return SubstraitMapping( diff --git a/src/functions/generators/explode.incn b/src/functions/generators/explode.incn new file mode 100644 index 0000000..1b6f2ed --- /dev/null +++ b/src/functions/generators/explode.incn @@ -0,0 +1,42 @@ +"""Inner explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import function_registry +from generator_builders import GeneratorApplication, explode as explode_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import explode_extension_uri + + +@function_registry.add("explode", deterministic_spec( + FunctionClass.Generator, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + relation_extension_mapping("explode", explode_extension_uri()), +)) +pub def explode(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """ + Build an inner row-expanding generator for array values. + + Examples: + generated = explode(col("line_items"), "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + as_: Output alias for the generated value column. + """ + return explode_builder(expr, as_) + + +module tests: + from projection_builders import col + def test_explode_builds_generator_application() -> None: + generator = explode(col("line_items"), "line_item") + assert generator.canonical_name == "explode" + assert generator.output_columns[0] == "line_item" diff --git a/src/functions/generators/explode_outer.incn b/src/functions/generators/explode_outer.incn new file mode 100644 index 0000000..bdbc1c9 --- /dev/null +++ b/src/functions/generators/explode_outer.incn @@ -0,0 +1,42 @@ +"""Outer explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import function_registry +from generator_builders import GeneratorApplication, explode_outer as explode_outer_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import explode_outer_extension_uri + + +@function_registry.add("explode_outer", deterministic_spec( + FunctionClass.Generator, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + relation_extension_mapping("explode_outer", explode_outer_extension_uri()), +)) +pub def explode_outer(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """ + Build an outer row-expanding generator for array values. + + Examples: + generated = explode_outer(col("line_items"), "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + as_: Output alias for the generated nullable value column. + """ + return explode_outer_builder(expr, as_) + + +module tests: + from projection_builders import col + def test_explode_outer_builds_outer_generator_application() -> None: + generator = explode_outer(col("line_items"), "line_item") + assert generator.canonical_name == "explode_outer" + assert generator.is_outer diff --git a/src/functions/generators/mod.incn b/src/functions/generators/mod.incn new file mode 100644 index 0000000..4865e2b --- /dev/null +++ b/src/functions/generators/mod.incn @@ -0,0 +1,6 @@ +"""Relation-shaping generator helpers.""" + +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.posexplode import posexplode +pub from functions.generators.posexplode_outer import posexplode_outer diff --git a/src/functions/generators/posexplode.incn b/src/functions/generators/posexplode.incn new file mode 100644 index 0000000..b4d5185 --- /dev/null +++ b/src/functions/generators/posexplode.incn @@ -0,0 +1,44 @@ +"""Inner positional explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import function_registry +from generator_builders import GeneratorApplication, posexplode as posexplode_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import posexplode_extension_uri + + +@function_registry.add("posexplode", deterministic_spec( + FunctionClass.Generator, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + relation_extension_mapping("posexplode", posexplode_extension_uri()), +)) +pub def posexplode(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """ + Build an inner row-expanding generator with a zero-based position column. + + Examples: + generated = posexplode(col("line_items"), "position", "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + position_as: Output alias for the zero-based position column. + value_as: Output alias for the generated value column. + """ + return posexplode_builder(expr, position_as, value_as) + + +module tests: + from projection_builders import col + def test_posexplode_builds_positional_generator_application() -> None: + generator = posexplode(col("line_items"), "position", "line_item") + assert generator.canonical_name == "posexplode" + assert generator.position_origin == 0 + assert generator.output_columns[0] == "position" diff --git a/src/functions/generators/posexplode_outer.incn b/src/functions/generators/posexplode_outer.incn new file mode 100644 index 0000000..20bda72 --- /dev/null +++ b/src/functions/generators/posexplode_outer.incn @@ -0,0 +1,44 @@ +"""Outer positional explode generator helper.""" + +from function_registry import ( + FunctionClass, + FunctionLifecycle, + FunctionNullBehavior, + deterministic_spec, + relation_extension_mapping, + v0_1, +) +from functions.registry import function_registry +from generator_builders import GeneratorApplication, posexplode_outer as posexplode_outer_builder +from projection_builders import ColumnExpr +from substrait.function_extensions import posexplode_outer_extension_uri + + +@function_registry.add("posexplode_outer", deterministic_spec( + FunctionClass.Generator, + FunctionLifecycle(since=v0_1, changed=[], deprecated=None), + FunctionNullBehavior.DependsOnInputs, + relation_extension_mapping("posexplode_outer", posexplode_outer_extension_uri()), +)) +pub def posexplode_outer(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """ + Build an outer row-expanding generator with a zero-based position column. + + Examples: + generated = posexplode_outer(col("line_items"), "position", "line_item") + + Parameters: + expr: Array expression to expand into generated rows. + position_as: Output alias for the zero-based position column. + value_as: Output alias for the generated nullable value column. + """ + return posexplode_outer_builder(expr, position_as, value_as) + + +module tests: + from projection_builders import col + def test_posexplode_outer_builds_outer_positional_generator_application() -> None: + generator = posexplode_outer(col("line_items"), "position", "line_item") + assert generator.canonical_name == "posexplode_outer" + assert generator.is_outer + assert generator.output_columns[1] == "line_item" diff --git a/src/functions/mod.incn b/src/functions/mod.incn index e0e8754..c20b662 100644 --- a/src/functions/mod.incn +++ b/src/functions/mod.incn @@ -61,6 +61,10 @@ pub from functions.nested.map_from_arrays import map_from_arrays pub from functions.nested.map_keys import map_keys pub from functions.nested.map_values import map_values pub from functions.nested.named_struct import named_struct +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.posexplode import posexplode +pub from functions.generators.posexplode_outer import posexplode_outer pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div diff --git a/src/generator_builders.incn b/src/generator_builders.incn new file mode 100644 index 0000000..d3d6b16 --- /dev/null +++ b/src/generator_builders.incn @@ -0,0 +1,150 @@ +""" +Relation-shaping generator builder surface. + +Generators are not scalar expressions: they may change row cardinality and append output columns. This module carries +the authoring intent through Dataset, Prism, and Substrait boundaries without making generators valid in ordinary +row-level expression positions. +""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import function_ref_for +from projection_builders import ColumnExpr + + +@derive(Clone) +pub enum GeneratorKind(str): + """Supported relation-shaping generator kinds in the current portable slice.""" + + Explode = "explode" + ExplodeOuter = "explode_outer" + PosExplode = "posexplode" + PosExplodeOuter = "posexplode_outer" + + +@derive(Clone) +pub model GeneratorApplication: + """One registry-backed relation-shaping generator application.""" + + pub kind: GeneratorKind + pub function_ref: str + pub canonical_name: str + pub expr: ColumnExpr + pub output_columns: list[str] + pub preserves_input_columns: bool + pub is_outer: bool + pub position_origin: int + + +pub def explode(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """Build an inner `explode` generator that appends one value column.""" + return _generator_application("explode", GeneratorKind.Explode, expr, [as_], true, false, 0) + + +pub def explode_outer(expr: ColumnExpr, as_: str) -> GeneratorApplication: + """Build an outer `explode` generator that appends one nullable value column.""" + return _generator_application("explode_outer", GeneratorKind.ExplodeOuter, expr, [as_], true, true, 0) + + +pub def posexplode(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """Build an inner positional explode generator with zero-based positions.""" + return _generator_application("posexplode", GeneratorKind.PosExplode, expr, [position_as, value_as], true, false, 0) + + +pub def posexplode_outer(expr: ColumnExpr, position_as: str, value_as: str) -> GeneratorApplication: + """Build an outer positional explode generator with zero-based positions.""" + return _generator_application( + "posexplode_outer", + GeneratorKind.PosExplodeOuter, + expr, + [position_as, value_as], + true, + true, + 0, + ) + + +pub def generator_output_columns(input_columns: list[str], generator: GeneratorApplication) -> list[str]: + """Return output columns after applying one generator to the provided input columns.""" + mut output_columns: list[str] = [] + if generator.preserves_input_columns: + output_columns.extend(input_columns) + for output_column in generator.output_columns: + if _contains_text(output_columns, output_column): + message = f"generator output column `{output_column}` conflicts with an existing column" + return raise_value_error(message) + output_columns.append(output_column) + return output_columns + + +pub def generator_primary_output_column(generator: GeneratorApplication) -> str: + """Return the primary generated value column for inspection and tests.""" + if len(generator.output_columns) == 0: + return "" + return generator.output_columns[len(generator.output_columns) - 1] + + +def _generator_application( + canonical_name: str, + kind: GeneratorKind, + expr: ColumnExpr, + output_columns: list[str], + preserves_input_columns: bool, + is_outer: bool, + position_origin: int, +) -> GeneratorApplication: + """Build one generator application after validating declared output aliases.""" + _validate_output_columns(canonical_name, output_columns) + return GeneratorApplication( + kind=kind, + function_ref=function_ref_for(canonical_name), + canonical_name=canonical_name, + expr=expr, + output_columns=output_columns, + preserves_input_columns=preserves_input_columns, + is_outer=is_outer, + position_origin=position_origin, + ) + + +def _validate_output_columns(canonical_name: str, output_columns: list[str]) -> None: + """Validate mandatory generator output aliases.""" + if len(output_columns) == 0: + message = f"{canonical_name} requires at least one output alias" + return raise_value_error(message) + mut seen: list[str] = [] + for output_column in output_columns: + if len(output_column) == 0: + message = f"{canonical_name} output aliases must be non-empty" + return raise_value_error(message) + if _contains_text(seen, output_column): + message = f"{canonical_name} output alias `{output_column}` is duplicated" + return raise_value_error(message) + seen.append(output_column) + return + + +def _contains_text(values: list[str], expected: str) -> bool: + """Return whether a string list contains a value.""" + for value in values: + if value == expected: + return true + return false + + +module tests: + from projection_builders import col, column_expr_name + def test_explode_application_records_function_identity_and_output_column() -> None: + generator = explode(col("line_items"), "line_item") + assert generator.kind == GeneratorKind.Explode + assert generator.canonical_name == "explode" + assert generator.function_ref == "inql.functions.explode" + assert column_expr_name(generator.expr) == "line_items" + assert generator.output_columns[0] == "line_item" + assert generator.preserves_input_columns + assert not generator.is_outer + def test_posexplode_uses_zero_based_position_origin() -> None: + generator = posexplode(col("line_items"), "pos", "line_item") + assert generator.kind == GeneratorKind.PosExplode + assert generator.position_origin == 0 + assert generator.output_columns[0] == "pos" + assert generator.output_columns[1] == "line_item" diff --git a/src/lib.incn b/src/lib.incn index a1767b6..87604d7 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -6,8 +6,24 @@ Consumers depend on this package via `[dependencies]` and import with `from pub: """ pub from dataset import BoundedDataSet, DataFrame, DataSet, DataStream, LazyFrame, UnboundedDataSet -pub from dataset.ops import agg_ds, explode_ds, filter_ds, group_by_ds, join_ds, limit_ds, order_by_ds, select_ds +pub from dataset.ops import ( + agg_ds, + explode_ds, + filter_ds, + generate_ds, + group_by_ds, + join_ds, + limit_ds, + order_by_ds, + select_ds, +) pub from aggregate_builders import AggregateKind, AggregateMeasure +pub from generator_builders import ( + GeneratorApplication, + GeneratorKind, + generator_output_columns, + generator_primary_output_column, +) pub from projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -81,6 +97,10 @@ pub from functions.nested.map_from_arrays import map_from_arrays pub from functions.nested.map_keys import map_keys pub from functions.nested.map_values import map_values pub from functions.nested.named_struct import named_struct +pub from functions.generators.explode import explode +pub from functions.generators.explode_outer import explode_outer +pub from functions.generators.posexplode import posexplode +pub from functions.generators.posexplode_outer import posexplode_outer pub from functions.operators.add import add pub from functions.operators.and_ import and_ pub from functions.operators.div import div @@ -142,6 +162,7 @@ pub from function_registry import ( function_policy_spec, namespaced_function_ref, rejected_function_policy, + relation_extension_mapping, rewrite_mapping, sort_field_mapping, structural_mapping, @@ -184,6 +205,8 @@ pub from substrait.relations import ( extension_single_rel, fetch_rel, filter_rel, + generator_rel, + generator_rel_of_columns, join_rel, join_rel_of_kind, project_rel, @@ -211,6 +234,8 @@ pub from substrait.inspect import ( aggregate_measure_function_names, aggregate_measure_invocation_names, aggregate_measure_sort_counts, + plan_extension_urn_anchor_at, + plan_extension_urn_count, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -222,7 +247,10 @@ pub from substrait.inspect import ( ) pub from substrait.function_extensions import ( explode_extension_uri, + explode_outer_extension_uri, function_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, registered_substrait_extension_uris, ) pub from substrait.conformance_catalog import ( diff --git a/src/prism/lower.incn b/src/prism/lower.incn index 9ae303c..6020b57 100644 --- a/src/prism/lower.incn +++ b/src/prism/lower.incn @@ -10,6 +10,7 @@ from substrait.relations import ( fetch_rel, join_rel, read_named_table_rel, + try_generator_rel_of_columns, sort_rel_of_columns, try_aggregate_rel_of_columns, try_filter_rel_of_columns, @@ -118,6 +119,12 @@ def _try_lower_node(view: PrismOptimizedView, node_id: int) -> Result[Rel, Subst [], node.aggregate_measures, ) + PrismNodeKind.Generate => + return try_generator_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, + rewritten_output_columns(view, node.input_ids[0]), + node.generator_applications[0], + ) PrismNodeKind.OrderBy => return Ok( sort_rel_of_columns( diff --git a/src/prism/mod.incn b/src/prism/mod.incn index 229cbaa..3564d35 100644 --- a/src/prism/mod.incn +++ b/src/prism/mod.incn @@ -13,6 +13,7 @@ This façade keeps one stable internal import surface while the implementation i from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure from filter_builders import always_true +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment from prism.lower import ( lower_prism_tip as lower_prism_tip_impl, @@ -69,6 +70,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -87,6 +89,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -102,6 +105,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -119,6 +123,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -136,6 +141,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[with_column_assignment(name, expr)], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -153,6 +159,7 @@ pub class PrismCursor[T with Clone]: group_columns=columns, sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -170,6 +177,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=measures, + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -187,6 +195,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=columns, aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -204,6 +213,7 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -221,6 +231,25 @@ pub class PrismCursor[T with Clone]: group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], + projection_assignments=[], + ) + return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) + + def generate(self, generator: GeneratorApplication) -> Self: + """Append one explicit generator node and return the derived tip.""" + next_tip_id = append_node( + store_id=self.store_id, + kind=PrismNodeKind.Generate, + input_ids=[self.tip_id], + named_table="", + join_predicate=false, + filter_predicate=always_true(), + limit_count=0, + group_columns=[], + sort_columns=[], + aggregate_measures=[], + generator_applications=[generator], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -264,6 +293,7 @@ pub def prism_cursor_named_table[T with Clone](table_name: str) -> PrismCursor[T group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) return PrismCursor(store_id=store_id, tip_id=tip_id, _type_marker=[]) @@ -325,6 +355,14 @@ pub def prism_cursor_apply_explode[T with Clone](cursor: PrismCursor[T]) -> Pris return cursor.explode() +pub def prism_cursor_apply_generate[T with Clone]( + cursor: PrismCursor[T], + generator: GeneratorApplication, +) -> PrismCursor[T]: + """Apply one explicit generator through Prism.""" + return cursor.generate(generator) + + pub def prism_cursor_output_columns[T with Clone](cursor: PrismCursor[T]) -> list[str]: """Return plan-time output columns for one cursor tip.""" return cursor.planned_columns() diff --git a/src/prism/output_columns.incn b/src/prism/output_columns.incn index f1de58c..d1cfa06 100644 --- a/src/prism/output_columns.incn +++ b/src/prism/output_columns.incn @@ -3,6 +3,7 @@ from prism.store import node_at from prism.rewrite import rewritten_node_at from prism.types import PrismNodeKind, PrismOptimizedView, PrismStoreId +from generator_builders import generator_output_columns from projection_builders import ColumnExpr, project_output_columns, scalar_expr_output_name from substrait.inspect import aggregate_measure_output_names from substrait.schema_registry import named_table_columns @@ -27,6 +28,11 @@ pub def authored_output_columns(store_id: PrismStoreId, tip_id: int) -> list[str return authored_output_columns(store_id, node.input_ids[0]) if node.kind == PrismNodeKind.Project: return project_output_columns(authored_output_columns(store_id, node.input_ids[0]), node.projection_assignments) + if node.kind == PrismNodeKind.Generate: + return generator_output_columns( + authored_output_columns(store_id, node.input_ids[0]), + node.generator_applications[0], + ) if node.kind == PrismNodeKind.Join: # Join output columns preserve the conventional left-then-right relation order. # We keep both sides verbatim here; duplicate names are part of the current output shape and are resolved later @@ -59,6 +65,11 @@ pub def rewritten_output_columns(view: PrismOptimizedView, node_id: int) -> list return rewritten_output_columns(view, node.input_ids[0]) if node.kind == PrismNodeKind.Project: return project_output_columns(rewritten_output_columns(view, node.input_ids[0]), node.projection_assignments) + if node.kind == PrismNodeKind.Generate: + return generator_output_columns( + rewritten_output_columns(view, node.input_ids[0]), + node.generator_applications[0], + ) if node.kind == PrismNodeKind.Join: # Rewritten views keep the same left-then-right join column order as authored views # so output-column inference stays stable across Prism rewrite passes. diff --git a/src/prism/rewrite.incn b/src/prism/rewrite.incn index 6247b0b..419f968 100644 --- a/src/prism/rewrite.incn +++ b/src/prism/rewrite.incn @@ -168,6 +168,7 @@ def _build_collapsed_limit_node( group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) @@ -204,6 +205,7 @@ def _build_collapsed_project_node( group_columns=[], sort_columns=[], aggregate_measures=[], + generator_applications=[], projection_assignments=merged_assignments, ) @@ -240,6 +242,7 @@ def _build_collapsed_aggregate_node( group_columns=[], sort_columns=[], aggregate_measures=merged_measures, + generator_applications=[], projection_assignments=[], ) @@ -274,6 +277,7 @@ def _build_collapsed_order_by_node( group_columns=[], sort_columns=node.sort_columns, aggregate_measures=[], + generator_applications=[], projection_assignments=[], ) @@ -291,6 +295,7 @@ def _build_rewritten_node(node: PrismNode, remapped_inputs: list[int], rewritten group_columns=node.group_columns, sort_columns=node.sort_columns, aggregate_measures=node.aggregate_measures, + generator_applications=node.generator_applications, projection_assignments=node.projection_assignments, ) @@ -336,6 +341,7 @@ def _compact_optimized_view(view: PrismOptimizedView) -> PrismOptimizedView: group_columns=old_node.group_columns, sort_columns=old_node.sort_columns, aggregate_measures=old_node.aggregate_measures, + generator_applications=old_node.generator_applications, projection_assignments=old_node.projection_assignments, ), ) diff --git a/src/prism/store.incn b/src/prism/store.incn index d620574..e451ade 100644 --- a/src/prism/store.incn +++ b/src/prism/store.incn @@ -1,6 +1,7 @@ """Append-only Prism store allocation, storage, reachability, and cross-store adoption.""" from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -54,6 +55,7 @@ pub def append_node( group_columns: list[ColumnExpr], sort_columns: list[ColumnExpr], aggregate_measures: list[AggregateMeasure], + generator_applications: list[GeneratorApplication], projection_assignments: list[ProjectionAssignment], ) -> int: """ @@ -73,6 +75,7 @@ pub def append_node( group_columns=group_columns, sort_columns=sort_columns, aggregate_measures=aggregate_measures, + generator_applications=generator_applications, projection_assignments=projection_assignments, ) prism_stored_nodes.append(PrismStoredNode(store_id_raw=store_id.0, node=appended)) @@ -119,10 +122,12 @@ pub def adopt_cursor_subgraph( adopted_group_columns = [column for column in source_node.group_columns] adopted_sort_columns = [column for column in source_node.sort_columns] adopted_measures = [measure for measure in source_node.aggregate_measures] + adopted_generators = [generator for generator in source_node.generator_applications] adopted_assignments = [assignment for assignment in source_node.projection_assignments] target_group_columns = [column for column in source_node.group_columns] target_sort_columns = [column for column in source_node.sort_columns] target_measures = [measure for measure in source_node.aggregate_measures] + target_generators = [generator for generator in source_node.generator_applications] target_assignments = [assignment for assignment in source_node.projection_assignments] adopted_id = append_node( store_id=target_store_id, @@ -135,6 +140,7 @@ pub def adopt_cursor_subgraph( group_columns=adopted_group_columns, sort_columns=adopted_sort_columns, aggregate_measures=adopted_measures, + generator_applications=adopted_generators, projection_assignments=adopted_assignments, ) target_store_nodes.append( @@ -149,6 +155,7 @@ pub def adopt_cursor_subgraph( group_columns=target_group_columns, sort_columns=target_sort_columns, aggregate_measures=target_measures, + generator_applications=target_generators, projection_assignments=target_assignments, ), ) @@ -232,6 +239,11 @@ def _nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema return false if not _aggregate_measure_lists_structurally_equal(candidate.aggregate_measures, source_node.aggregate_measures): return false + if not _generator_application_lists_structurally_equal( + candidate.generator_applications, + source_node.generator_applications, + ): + return false if not _projection_assignments_structurally_equal( candidate.projection_assignments, source_node.projection_assignments, @@ -271,6 +283,48 @@ def _aggregate_measures_structurally_equal(left: AggregateMeasure, right: Aggreg return _column_expr_lists_structurally_equal(left.ordering, right.ordering) +def _generator_application_lists_structurally_equal( + left: list[GeneratorApplication], + right: list[GeneratorApplication], +) -> bool: + """Return whether two generator-application lists carry identical relation-shaping semantics.""" + if len(left) != len(right): + return false + for idx in range(len(left)): + if not _generator_applications_structurally_equal(left[idx], right[idx]): + return false + return true + + +def _generator_applications_structurally_equal(left: GeneratorApplication, right: GeneratorApplication) -> bool: + """Return whether two generator applications carry identical registry identity and schema effects.""" + if left.kind != right.kind: + return false + if left.function_ref != right.function_ref: + return false + if left.canonical_name != right.canonical_name: + return false + if left.preserves_input_columns != right.preserves_input_columns: + return false + if left.is_outer != right.is_outer: + return false + if left.position_origin != right.position_origin: + return false + if not _text_lists_structurally_equal(left.output_columns, right.output_columns): + return false + return _column_exprs_structurally_equal(left.expr, right.expr) + + +def _text_lists_structurally_equal(left: list[str], right: list[str]) -> bool: + """Return whether two string lists are structurally equivalent.""" + if len(left) != len(right): + return false + for idx in range(len(left)): + if left[idx] != right[idx]: + return false + return true + + def _filter_predicates_structurally_equal(left: ColumnExpr, right: ColumnExpr) -> bool: """Return whether two filter scalar expressions are structurally equivalent.""" return _column_exprs_structurally_equal(left, right) diff --git a/src/prism/types.incn b/src/prism/types.incn index a5573cf..59472c1 100644 --- a/src/prism/types.incn +++ b/src/prism/types.incn @@ -1,6 +1,7 @@ """Shared Prism types that define the internal planning substrate contract.""" from aggregate_builders import AggregateMeasure +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment @@ -17,6 +18,7 @@ pub enum PrismNodeKind(str): Project = "Project" GroupBy = "GroupBy" Aggregate = "Aggregate" + Generate = "Generate" OrderBy = "OrderBy" Limit = "Limit" Explode = "Explode" @@ -41,6 +43,7 @@ pub model PrismNode: pub group_columns: list[ColumnExpr] pub sort_columns: list[ColumnExpr] pub aggregate_measures: list[AggregateMeasure] + pub generator_applications: list[GeneratorApplication] pub projection_assignments: list[ProjectionAssignment] diff --git a/src/substrait/expr_lowering.incn b/src/substrait/expr_lowering.incn index d5dcd72..21384eb 100644 --- a/src/substrait/expr_lowering.incn +++ b/src/substrait/expr_lowering.incn @@ -281,6 +281,12 @@ def _resolved_scalar_function_application_expr( f"{entry.function_ref} is only valid in {entry.substrait.function_name} context", ), ) + SubstraitMappingKind.RelationExtension => + return Err( + invalid_scalar_expression( + f"{entry.function_ref} is a relation-shaping generator and must be applied through generate(...)", + ), + ) SubstraitMappingKind.Rewrite => return Err( invalid_scalar_expression( diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index f4efeb4..4f0edad 100644 --- a/src/substrait/extensions.incn +++ b/src/substrait/extensions.incn @@ -6,6 +6,7 @@ expression trees. """ from rust::incan_stdlib::errors import raise_value_error +from rust::std::primitive import u32 as RustU32 from rust::substrait::proto import AggregateFunction, Expression, FunctionArgument, Rel, SortField from rust::substrait::proto::extensions import SimpleExtensionDeclaration, SimpleExtensionUrn from rust::substrait::proto::extensions::simple_extension_declaration import ExtensionFunction, MappingType @@ -28,7 +29,23 @@ model ExtensionUrnSpec: const FUNCTION_EXTENSION_URN_ANCHOR: u32 = 0 -const RELATION_EXTENSION_URN_ANCHOR: u32 = 1 + + +def _to_extension_urn_anchor(value: int) -> RustU32: + """Convert a small extension-URN anchor into the protobuf field type.""" + match RustU32.try_from(value): + Ok(converted) => return converted + Err(_) => + message = f"extension URN anchor {value} does not fit Rust u32" + return raise_value_error(message) + + +def _has_extension_urn_spec(specs: list[ExtensionUrnSpec], urn: str) -> bool: + """Return whether a plan-level extension URN list already contains one URI.""" + for spec in specs: + if spec.urn == urn: + return true + return false pub def aggregate_function_name_from_anchor(anchor: u32) -> str: @@ -407,6 +424,10 @@ pub def extension_urns_for_rel(rel: Rel) -> list[SimpleExtensionUrn]: mut specs: list[ExtensionUrnSpec] = [] if _function_extension_urn_is_required(rel.clone()): specs.append(ExtensionUrnSpec(anchor=FUNCTION_EXTENSION_URN_ANCHOR, urn=function_extension_uri())) + mut relation_anchor_count = 0 for urn in _collect_extension_urn_strings(rel): - specs.append(ExtensionUrnSpec(anchor=RELATION_EXTENSION_URN_ANCHOR, urn=urn)) + if _has_extension_urn_spec(specs, urn): + continue + relation_anchor_count += 1 + specs.append(ExtensionUrnSpec(anchor=_to_extension_urn_anchor(relation_anchor_count), urn=urn)) return [SimpleExtensionUrn(extension_urn_anchor=spec.anchor, urn=spec.urn) for spec in specs] diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn index 490f93c..649a680 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -77,6 +77,9 @@ pub const ARRAY_HAS_ANY_FUNCTION_ANCHOR: u32 = 50 pub const ARRAY_FLATTEN_FUNCTION_ANCHOR: u32 = 51 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" +const EXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode_outer" +const POSEXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#posexplode" +const POSEXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#posexplode_outer" pub def function_extension_uri() -> str: @@ -89,6 +92,21 @@ pub def explode_extension_uri() -> str: return EXPLODE_EXTENSION_URI +pub def explode_outer_extension_uri() -> str: + """Return the registered extension URI used for outer EXPLODE gap encoding.""" + return EXPLODE_OUTER_EXTENSION_URI + + +pub def posexplode_extension_uri() -> str: + """Return the registered extension URI used for positional EXPLODE gap encoding.""" + return POSEXPLODE_EXTENSION_URI + + +pub def posexplode_outer_extension_uri() -> str: + """Return the registered extension URI used for outer positional EXPLODE gap encoding.""" + return POSEXPLODE_OUTER_EXTENSION_URI + + pub def registered_substrait_extension_uris() -> list[str]: """Return the registered extension URIs used by current package-level Substrait lowering.""" - return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI] + return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI, EXPLODE_OUTER_EXTENSION_URI, POSEXPLODE_EXTENSION_URI, POSEXPLODE_OUTER_EXTENSION_URI] diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index 37005e0..063061f 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -19,8 +19,14 @@ from rust::substrait::proto::set_rel import SetOp from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure from projection_builders import scalar_expr_output_name -from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit +from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit, rust_u32_to_int from substrait.extensions import aggregate_function_name_from_anchor +from substrait.function_extensions import ( + explode_extension_uri, + explode_outer_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, +) from substrait.schema_registry import named_table_columns, unknown_named_struct from substrait.traversal import relation_children @@ -185,7 +191,11 @@ def _relation_output_columns(rel: Rel) -> list[str]: None => return [] Some(RelType.ExtensionSingle(extension_rel)) => match extension_rel.input: - Some(child) => return _relation_output_columns(child.as_ref().clone()) + Some(child) => + input_columns = _relation_output_columns(child.as_ref().clone()) + match extension_rel.detail: + Some(detail) => return _extension_single_output_columns(input_columns, detail.type_url) + None => return input_columns None => return [] Some(RelType.Join(join_rel)) => mut names: list[str] = [] @@ -218,6 +228,18 @@ pub def relation_output_columns(rel: Rel) -> list[str]: return _relation_output_columns(rel) +def _extension_single_output_columns(input_columns: list[str], extension_uri: str) -> list[str]: + """Return best-effort output columns for known extension-single relation encodings.""" + mut columns: list[str] = [] + columns.extend(input_columns) + if extension_uri == explode_extension_uri() or extension_uri == explode_outer_extension_uri(): + columns.append("value") + elif extension_uri == posexplode_extension_uri() or extension_uri == posexplode_outer_extension_uri(): + columns.append("position") + columns.append("value") + return columns + + pub def aggregate_measure_function_names(rel: Rel) -> list[str]: """Return aggregate function names used by a top-level AggregateRel, otherwise empty.""" match rel.rel_type: @@ -453,3 +475,15 @@ pub def plan_has_extension_urn(plan: Plan, extension_uri: str) -> bool: if urn.urn == extension_uri: return true return false + + +pub def plan_extension_urn_count(plan: Plan) -> int: + """Return the number of extension URN declarations carried by one plan.""" + return len(plan.extension_urns) + + +pub def plan_extension_urn_anchor_at(plan: Plan, index: int) -> int: + """Return one extension URN anchor as an Incan integer for tests and diagnostics.""" + if index < 0 or index >= len(plan.extension_urns): + return -1 + return rust_u32_to_int(plan.extension_urns[index].extension_urn_anchor) diff --git a/src/substrait/mod.incn b/src/substrait/mod.incn index 16e0f38..2f15c20 100644 --- a/src/substrait/mod.incn +++ b/src/substrait/mod.incn @@ -26,6 +26,8 @@ pub from substrait.relations import ( fetch_rel, filter_rel, filter_rel_of_columns, + generator_rel, + generator_rel_of_columns, join_rel, join_rel_of_kind, project_rel, @@ -58,6 +60,8 @@ pub from substrait.inspect import ( aggregate_measure_invocation_names, aggregate_measure_output_names, aggregate_measure_sort_counts, + plan_extension_urn_anchor_at, + plan_extension_urn_count, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -76,6 +80,9 @@ pub from substrait.inspect import ( ) pub from substrait.function_extensions import ( explode_extension_uri, + explode_outer_extension_uri, function_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, registered_substrait_extension_uris, ) diff --git a/src/substrait/relations.incn b/src/substrait/relations.incn index 849beba..b075e5f 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -46,6 +46,7 @@ from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure from function_registry import FunctionClass, FunctionRegistryEntry, SubstraitMappingKind from functions.registry import function_registry_entry +from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, ScalarFunctionApplicationExpr, col from substrait.expr_lowering import ( bool_expr, @@ -81,6 +82,15 @@ model ResolvedRelationExpression: expr: Expression +@derive(Clone) +model ResolvedGeneratorApplication: + """One generator application resolved against input columns and registry metadata.""" + + generator: GeneratorApplication + entry: FunctionRegistryEntry + expr: Expression + + pub enum SubstraitJoinKind: Inner Left @@ -259,6 +269,61 @@ def _validate_aggregate_modifiers(measure: ResolvedAggregateMeasure) -> Result[N return Ok(None) +def _generator_registry_entry(generator: GeneratorApplication) -> Result[FunctionRegistryEntry, SubstraitLoweringError]: + """Resolve one generator registry entry and validate its semantic class.""" + match function_registry_entry(generator.function_ref): + Some(entry) => + if entry.function_class != FunctionClass.Generator: + return Err(invalid_scalar_expression(f"{entry.function_ref} is not registered as a generator function")) + if entry.substrait.kind != SubstraitMappingKind.RelationExtension: + return Err( + invalid_scalar_expression(f"{entry.function_ref} does not declare a relation-extension mapping"), + ) + return Ok(entry) + None => + return Err(invalid_scalar_expression(f"missing generator registry entry for `{generator.canonical_name}`")) + + +def _resolved_generator( + generator: GeneratorApplication, + input_columns: list[str], +) -> Result[ResolvedGeneratorApplication, SubstraitLoweringError]: + """Resolve one generator application against input-column names.""" + _validate_generator_output_columns(input_columns, generator.clone())? + return Ok( + ResolvedGeneratorApplication( + generator=generator.clone(), + entry=_generator_registry_entry(generator.clone())?, + expr=scalar_expr(input_columns, generator.expr)?, + ), + ) + + +def _validate_generator_output_columns( + input_columns: list[str], + generator: GeneratorApplication, +) -> Result[None, SubstraitLoweringError]: + """Validate generator output columns against the current input relation shape.""" + mut output_columns: list[str] = [] + if generator.preserves_input_columns: + output_columns.extend(input_columns) + for output_column in generator.output_columns: + if _contains_text(output_columns, output_column): + return Err( + invalid_scalar_expression(f"generator output column `{output_column}` conflicts with an existing column"), + ) + output_columns.append(output_column) + return Ok(None) + + +def _contains_text(values: list[str], expected: str) -> bool: + """Return whether a string list contains a value.""" + for value in values: + if value == expected: + return true + return false + + def _aggregate_function_reference(measure: ResolvedAggregateMeasure) -> Result[u32, SubstraitLoweringError]: """Resolve one aggregate measure through declaration-side registry metadata.""" match _aggregate_registry_entry(measure): @@ -603,6 +668,31 @@ pub def try_aggregate_rel_of_columns( ) +pub def generator_rel(input: Rel, generator: GeneratorApplication) -> Rel: + """Wrap a child relation in a generator relation-extension node.""" + return _lowered_rel_or_raise(try_generator_rel(input, generator)) + + +pub def try_generator_rel(input: Rel, generator: GeneratorApplication) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a generator relation-extension node.""" + return try_generator_rel_of_columns(input.clone(), relation_output_columns(input), generator) + + +pub def generator_rel_of_columns(input: Rel, input_columns: list[str], generator: GeneratorApplication) -> Rel: + """Wrap a child relation in a generator relation-extension node using explicit input-column names.""" + return _lowered_rel_or_raise(try_generator_rel_of_columns(input, input_columns, generator)) + + +pub def try_generator_rel_of_columns( + input: Rel, + input_columns: list[str], + generator: GeneratorApplication, +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a generator relation-extension node using explicit input-column names.""" + resolved = _resolved_generator(generator, input_columns)? + return Ok(extension_single_rel(input, resolved.entry.substrait.uri)) + + pub def sort_rel(input: Rel) -> Rel: """Wrap a child relation in `SortRel` using the first known output column as the default sort key.""" input_columns = relation_output_columns(input.clone()) diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index 8d762b1..3140a03 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -17,6 +17,8 @@ from functions import ( count_expr, count_if, eq, + explode, + explode_outer, float_expr, int_expr, int_lit, @@ -24,12 +26,19 @@ from functions import ( max, min, mul, + posexplode, + posexplode_outer, str_expr, str_lit, sum, ) from projection_builders import ColumnExprKind, column_expr_kind, column_expr_name -from substrait.function_extensions import explode_extension_uri +from substrait.function_extensions import ( + explode_extension_uri, + explode_outer_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, +) from substrait.inspect import plan_contains_relation_kind, plan_has_extension_urn, relation_kind_name, root_rel from substrait.plans import plan_encoded_len, plan_from_named_table, plan_from_root_relation from substrait.relations import read_named_table_rel @@ -422,6 +431,14 @@ def test_lazy_frame__independent_roots_can_join_and_lower() -> None: def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None: # -- Arrange -- _register_order_schema("orders") + register_named_table_schema( + "orders_generator_dataset", + [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false), RowColumnSpec( + name="line_items", + kind=SubstraitPrimitiveKind.String, + nullable=true, + )], + ) projected: LazyFrame[Order] = lazy_frame_named_table("orders").select() grouped: LazyFrame[Order] = lazy_frame_named_table("orders").group_by([col("id")]) @@ -430,6 +447,18 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None ordered: LazyFrame[Order] = lazy_frame_named_table("orders").order_by([col("id")]) limited: LazyFrame[Order] = lazy_frame_named_table("orders").limit(10) exploded: LazyFrame[Order] = lazy_frame_named_table("orders").explode() + generated: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + explode(col("line_items"), "line_item"), + ) + generated_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + explode_outer(col("line_items"), "line_item"), + ) + generated_positional: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + posexplode(col("line_items"), "position", "line_item"), + ) + generated_positional_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( + posexplode_outer(col("line_items"), "position", "line_item"), + ) # -- Assert -- assert relation_kind_name(root_rel(projected.to_substrait_plan())) == "ProjectRel", "select should lower through the project boundary shape" @@ -438,6 +467,11 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None assert relation_kind_name(root_rel(ordered.to_substrait_plan())) == "SortRel", "order_by should lower to SortRel" assert relation_kind_name(root_rel(limited.to_substrait_plan())) == "FetchRel", "limit should lower to FetchRel" assert plan_has_extension_urn(exploded.to_substrait_plan(), explode_extension_uri()), "explode should keep emitting the registered extension boundary" + assert relation_kind_name(root_rel(generated.to_substrait_plan())) == "ExtensionSingleRel", "generate should lower through the relation extension boundary" + assert generated.planned_columns() == ["id", "line_items", "line_item"], "generate should append declared output aliases" + assert plan_has_extension_urn(generated_outer.to_substrait_plan(), explode_outer_extension_uri()), "outer explode should use its relation extension URI" + assert plan_has_extension_urn(generated_positional.to_substrait_plan(), posexplode_extension_uri()), "posexplode should use its relation extension URI" + assert plan_has_extension_urn(generated_positional_outer.to_substrait_plan(), posexplode_outer_extension_uri()), "posexplode_outer should use its relation extension URI" def test_lazy_frame__deeper_independent_roots_still_lower_with_stable_shapes() -> None: diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 22f8739..424147e 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -45,6 +45,8 @@ from functions import ( eq, equal_null, element_at, + explode, + explode_outer, floor, float_expr, function_registry_canonical_names, @@ -81,6 +83,8 @@ from functions import ( not_, nullif, or_, + posexplode, + posexplode_outer, registered_substrait_mapped_function_refs, round, str_expr, @@ -164,7 +168,11 @@ from substrait.function_extensions import ( ROUND_FUNCTION_ANCHOR, SUBTRACT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR, + explode_extension_uri, + explode_outer_extension_uri, function_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, ) @@ -223,7 +231,7 @@ def _local_entry_by_namespace_and_name_or_fail( def _expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer"] def _expected_substrait_mapped_names() -> list[str]: @@ -313,6 +321,10 @@ def _exercise_current_public_helpers() -> None: map_keys(attr_map) map_values(attr_map) named_struct(["status", "amount"], [status, amount]) + explode(tags, "tag") + explode_outer(tags, "tag") + posexplode(tags, "position", "tag") + posexplode_outer(tags, "position", "tag") return @@ -346,6 +358,15 @@ def _assert_extension_mapping(canonical_name: str, function_name: str, anchor: u assert entry.substrait.anchor == anchor, f"{canonical_name} should carry the stable Substrait anchor" +def _assert_relation_extension_mapping(canonical_name: str, function_name: str, extension_uri: str) -> None: + """Assert one generator helper declares a relation-extension mapping.""" + entry = _entry_or_fail(function_ref_for(canonical_name)) + assert entry.function_class == FunctionClass.Generator, f"{canonical_name} should be classified as a generator" + assert entry.substrait.kind == SubstraitMappingKind.RelationExtension, f"{canonical_name} should use a relation extension" + assert entry.substrait.uri == extension_uri, f"{canonical_name} should carry the registered relation extension URI" + assert entry.substrait.function_name == function_name, f"{canonical_name} should use the registered extension name" + + def _assert_core_mapping(canonical_name: str, function_name: str) -> None: """Assert one helper declares the expected built-in Substrait Rex mapping.""" entry = _entry_or_fail(function_ref_for(canonical_name)) @@ -608,6 +629,18 @@ def test_function_registry__substrait_extension_mappings_are_structured() -> Non _assert_extension_mapping("named_struct", "named_struct", NAMED_STRUCT_FUNCTION_ANCHOR) +def test_function_registry__generator_helpers_are_relation_extensions() -> None: + """Assert generator helpers are registry entries without scalar or aggregate extension anchors.""" + # -- Arrange -- + _exercise_current_public_helpers() + + # -- Act / Assert -- + _assert_relation_extension_mapping("explode", "explode", explode_extension_uri()) + _assert_relation_extension_mapping("explode_outer", "explode_outer", explode_outer_extension_uri()) + _assert_relation_extension_mapping("posexplode", "posexplode", posexplode_extension_uri()) + _assert_relation_extension_mapping("posexplode_outer", "posexplode_outer", posexplode_outer_extension_uri()) + + def test_function_registry__ordering_helpers_are_contextual_sort_fields() -> None: """Assert RFC 015 ordering helpers are modeled as sort-field context helpers.""" # -- Arrange -- diff --git a/tests/test_generator_functions.incn b/tests/test_generator_functions.incn new file mode 100644 index 0000000..052d784 --- /dev/null +++ b/tests/test_generator_functions.incn @@ -0,0 +1,55 @@ +"""Tests for registry-backed generator and table-valued function builders.""" + +from std.testing import assert_raises +from generator_builders import GeneratorKind, generator_output_columns, generator_primary_output_column +from functions import col, explode, explode_outer, posexplode, posexplode_outer + + +def test_generator_functions__explode_family_builds_relation_applications() -> None: + # -- Arrange -- + items = col("line_items") + + # -- Act -- + inner = explode(items, "line_item") + outer = explode_outer(items, "line_item") + positional = posexplode(items, "position", "line_item") + positional_outer = posexplode_outer(items, "position", "line_item") + + # -- Assert -- + assert inner.kind == GeneratorKind.Explode + assert outer.kind == GeneratorKind.ExplodeOuter + assert positional.kind == GeneratorKind.PosExplode + assert positional_outer.kind == GeneratorKind.PosExplodeOuter + assert not inner.is_outer + assert outer.is_outer + assert positional.position_origin == 0 + assert positional_outer.position_origin == 0 + assert generator_primary_output_column(inner) == "line_item" + assert generator_primary_output_column(positional) == "line_item" + + +def test_generator_functions__output_columns_preserve_input_then_append_aliases() -> None: + # -- Arrange -- + input_columns = ["id", "line_items"] + + # -- Act -- + exploded_columns = generator_output_columns(input_columns, explode(col("line_items"), "line_item")) + positional_columns = generator_output_columns(input_columns, posexplode(col("line_items"), "position", "line_item")) + + # -- Assert -- + assert exploded_columns == ["id", "line_items", "line_item"] + assert positional_columns == ["id", "line_items", "position", "line_item"] + + +def _call_generator_with_input_collision() -> None: + """Call generator output inference with a generated name that collides with input.""" + generator_output_columns(["id", "line_items"], explode(col("line_items"), "id")) + return + + +def test_generator_functions__output_alias_collisions_are_rejected() -> None: + # -- Arrange -- + call = _call_generator_with_input_collision + + # -- Act / Assert -- + assert_raises[ValueError](call) diff --git a/tests/test_prism.incn b/tests/test_prism.incn index f0c7490..fc1f097 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -1,9 +1,10 @@ """Internal Prism engine tests against the shared-store cursor substrate.""" -from functions import always_false, always_true, col, count, count_expr, lit, mul, sum +from functions import always_false, always_true, col, count, count_expr, explode, lit, mul, sum from prism import ( PrismCursor, prism_cursor_apply_filter, + prism_cursor_apply_generate, prism_cursor_apply_limit, prism_cursor_apply_select, prism_cursor_authored_node_count, @@ -35,6 +36,17 @@ def _register_projection_test_schema(table_name: str) -> None: register_named_table_schema(table_name, [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false)]) +def _register_generator_test_schema(table_name: str) -> None: + register_named_table_schema( + table_name, + [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false), RowColumnSpec( + name="line_items", + kind=SubstraitPrimitiveKind.String, + nullable=true, + )], + ) + + def test_prism__branching_keeps_base_reachable_history_small() -> None: # -- Arrange -- base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) @@ -206,6 +218,7 @@ def test_prism__cross_store_adoption_keeps_distinct_aggregate_modifier_state() - def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: # -- Arrange -- _register_projection_test_schema(str("orders")) + _register_generator_test_schema(str("orders_generator_prism")) projected: PrismCursor[Order] = prism_cursor_named_table(str("orders")).select() grouped: PrismCursor[Order] = prism_cursor_named_table(str("orders")).group_by([col("id")]) @@ -214,6 +227,9 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by([col("id")]) limited: PrismCursor[Order] = prism_cursor_named_table(str("orders")).limit(10) exploded: PrismCursor[Order] = prism_cursor_named_table(str("orders")).explode() + generated: PrismCursor[Order] = prism_cursor_named_table(str("orders_generator_prism")).generate( + explode(col("line_items"), "line_item"), + ) # -- Assert -- assert prism_cursor_tip_kind_name(projected) == str("Project"), "select should append a native project node" @@ -222,6 +238,8 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: assert prism_cursor_tip_kind_name(ordered) == str("OrderBy"), "order_by should append a native sort node" assert prism_cursor_tip_kind_name(limited) == str("Limit"), "limit should append a native limit node" assert prism_cursor_tip_kind_name(exploded) == str("Explode"), "explode should append a native explode node" + assert prism_cursor_tip_kind_name(generated) == str("Generate"), "generate should append a native generator node" + assert prism_cursor_output_columns(generated) == ["id", "line_items", "line_item"], "generate should append declared output aliases" def test_prism__rewrite_eliminates_filter_true_by_default() -> None: @@ -332,3 +350,21 @@ def test_prism__cursor_methods_match_apply_helpers() -> None: assert relation_kind_name(root_rel(via_methods.to_substrait_plan())) == relation_kind_name( root_rel(via_helpers.to_substrait_plan()), ), "method and helper paths should lower to equivalent root relation kinds" + + +def test_prism__generate_method_matches_apply_helper() -> None: + # -- Arrange -- + _register_generator_test_schema("orders_generator_apply") + base: PrismCursor[Order] = prism_cursor_named_table(str("orders_generator_apply")) + generator = explode(col("line_items"), "line_item") + + # -- Act -- + via_method = base.generate(generator) + via_helper = prism_cursor_apply_generate(base, generator) + + # -- Assert -- + assert prism_cursor_tip_kind_name(via_method) == prism_cursor_tip_kind_name(via_helper), "method and helper paths should produce the same generator node kind" + assert prism_cursor_output_columns(via_method) == ["id", "line_items", "line_item"], "generator helper should preserve planned output columns" + assert relation_kind_name(root_rel(via_method.to_substrait_plan())) == relation_kind_name( + root_rel(via_helper.to_substrait_plan()), + ), "generator method and helper paths should lower to equivalent root relation kinds" diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 17a70c6..04bb902 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -75,7 +75,10 @@ from substrait.errors import SubstraitLoweringErrorKind from substrait.expr_lowering import scalar_expr from substrait.function_extensions import ( explode_extension_uri, + explode_outer_extension_uri, function_extension_uri, + posexplode_extension_uri, + posexplode_outer_extension_uri, registered_substrait_extension_uris, ) from substrait.inspect import ( @@ -83,6 +86,8 @@ from substrait.inspect import ( aggregate_measure_filter_flags, aggregate_measure_invocation_names, aggregate_measure_sort_counts, + plan_extension_urn_anchor_at, + plan_extension_urn_count, plan_contains_relation_kind, plan_has_extension_urn, read_kind_name, @@ -591,13 +596,22 @@ def test_plan__extension_urns_are_surfaced() -> None: # -- Arrange -- extension_uri = explode_extension_uri() rel = extension_single_rel(read_named_table_rel("orders"), extension_uri) + nested = extension_single_rel( + extension_single_rel(read_named_table_rel("orders"), explode_extension_uri()), + posexplode_extension_uri(), + ) # -- Act -- plan = plan_from_root_relation(rel, ["id"]) + nested_plan = plan_from_root_relation(nested, ["id", "position", "value"]) # -- Assert -- assert plan_has_extension_urn(plan, extension_uri), "extension relation should populate extension URNs" assert plan_contains_relation_kind(plan, "ExtensionSingleRel"), "extension root should remain inspectable" + assert plan_has_extension_urn(nested_plan, explode_extension_uri()), "nested extension plans should include child extension URNs" + assert plan_has_extension_urn(nested_plan, posexplode_extension_uri()), "nested extension plans should include root extension URNs" + assert plan_extension_urn_count(nested_plan) == 2, "different relation extension URIs should be declared once each" + assert plan_extension_urn_anchor_at(nested_plan, 0) != plan_extension_urn_anchor_at(nested_plan, 1), "relation extension URNs should use distinct anchors" def test_plan__revision_pin_and_extension_registry_are_exported() -> None: @@ -611,9 +625,12 @@ def test_plan__revision_pin_and_extension_registry_are_exported() -> None: # -- Assert -- assert tag == "v0.63.0", "revision helpers should expose the currently targeted Substrait release tag" assert producer == "inql-rfc002", "revision helpers should expose the package producer label" - assert len(registered) == 2, "current package boundary should register both extension URIs" + assert len(registered) == 5, "current package boundary should register function and generator extension URIs" assert registered[0] == function_extension_uri(), "registry should include the shared function extension URI first" assert registered[1] == explode_extension_uri(), "registry should include the emitted explode extension URI" + assert registered[2] == explode_outer_extension_uri(), "registry should include the outer explode extension URI" + assert registered[3] == posexplode_extension_uri(), "registry should include the positional explode extension URI" + assert registered[4] == posexplode_outer_extension_uri(), "registry should include the outer positional explode extension URI" def test_conformance__core_scenarios_validate_emission_output() -> None: