Skip to content

Commit d07fdb3

Browse files
timsaucerclaude
andauthored
Add missing scalar functions (#1470)
* Add missing scalar functions: get_field, union_extract, union_tag, arrow_metadata, version, row Expose upstream DataFusion scalar functions that were not yet available in the Python API. Closes #1453. - get_field: extracts a field from a struct or map by name - union_extract: extracts a value from a union type by field name - union_tag: returns the active field name of a union type - arrow_metadata: returns Arrow field metadata (all or by key) - version: returns the DataFusion version string - row: alias for the struct constructor Note: arrow_try_cast was listed in the issue but does not exist in DataFusion 53, so it is not included. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Add tests for new scalar functions Tests for get_field, arrow_metadata, version, row, union_tag, and union_extract. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Accept str for field name and type parameters in scalar functions Allow arrow_cast, get_field, and union_extract to accept plain str arguments instead of requiring Expr wrappers. Also improve arrow_metadata test coverage and fix parameter shadowing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Accept str for key parameter in arrow_metadata for consistency Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Add doctest examples and fix docstring style for new scalar functions Replace Args/Returns sections with doctest Examples blocks for arrow_metadata, get_field, union_extract, union_tag, and version to match existing codebase conventions. Simplify row to alias-style docstring with See Also reference. Document that arrow_cast accepts both str and Expr for data_type. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Support pyarrow DataType in arrow_cast Allow arrow_cast to accept a pyarrow DataType in addition to str and Expr. The DataType is converted to its string representation before being passed to DataFusion. Adds test coverage for the new input type. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Document bracket syntax shorthand in get_field docstring Note that expr["field"] is a convenient alternative when the field name is a static string, and get_field is needed for dynamic expressions. Add a second doctest example showing the bracket syntax. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Fix arrow_cast with pyarrow DataType by delegating to Expr.cast Use the existing Rust-side PyArrowType<DataType> conversion via Expr.cast() instead of str() which produces pyarrow type names that DataFusion does not recognize. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Clarify when to use arrow_cast vs Expr.cast in docstring Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 99bc960 commit d07fdb3

File tree

3 files changed

+296
-9
lines changed

3 files changed

+296
-9
lines changed

crates/core/src/functions.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,29 @@ expr_fn_vec!(named_struct);
695695
expr_fn!(from_unixtime, unixtime);
696696
expr_fn!(arrow_typeof, arg_1);
697697
expr_fn!(arrow_cast, arg_1 datatype);
698+
expr_fn_vec!(arrow_metadata);
699+
expr_fn!(union_tag, arg1);
698700
expr_fn!(random);
699701

702+
#[pyfunction]
703+
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
704+
functions::core::get_field()
705+
.call(vec![expr.into(), name.into()])
706+
.into()
707+
}
708+
709+
#[pyfunction]
710+
fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr {
711+
functions::core::union_extract()
712+
.call(vec![union_expr.into(), field_name.into()])
713+
.into()
714+
}
715+
716+
#[pyfunction]
717+
fn version() -> PyExpr {
718+
functions::core::version().call(vec![]).into()
719+
}
720+
700721
// Array Functions
701722
array_fn!(array_append, array element);
702723
array_fn!(array_to_string, array delimiter);
@@ -1014,6 +1035,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
10141035
m.add_wrapped(wrap_pyfunction!(array_agg))?;
10151036
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
10161037
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
1038+
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
10171039
m.add_wrapped(wrap_pyfunction!(ascii))?;
10181040
m.add_wrapped(wrap_pyfunction!(asin))?;
10191041
m.add_wrapped(wrap_pyfunction!(asinh))?;
@@ -1142,6 +1164,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
11421164
m.add_wrapped(wrap_pyfunction!(trim))?;
11431165
m.add_wrapped(wrap_pyfunction!(trunc))?;
11441166
m.add_wrapped(wrap_pyfunction!(upper))?;
1167+
m.add_wrapped(wrap_pyfunction!(get_field))?;
1168+
m.add_wrapped(wrap_pyfunction!(union_extract))?;
1169+
m.add_wrapped(wrap_pyfunction!(union_tag))?;
1170+
m.add_wrapped(wrap_pyfunction!(version))?;
11451171
m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
11461172
m.add_wrapped(wrap_pyfunction!(var_pop))?;
11471173
m.add_wrapped(wrap_pyfunction!(var_sample))?;

python/datafusion/functions.py

Lines changed: 171 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
"arrays_overlap",
9999
"arrays_zip",
100100
"arrow_cast",
101+
"arrow_metadata",
101102
"arrow_typeof",
102103
"ascii",
103104
"asin",
@@ -163,6 +164,7 @@
163164
"gcd",
164165
"gen_series",
165166
"generate_series",
167+
"get_field",
166168
"greatest",
167169
"ifnull",
168170
"in_list",
@@ -280,6 +282,7 @@
280282
"reverse",
281283
"right",
282284
"round",
285+
"row",
283286
"row_number",
284287
"rpad",
285288
"rtrim",
@@ -322,12 +325,15 @@
322325
"translate",
323326
"trim",
324327
"trunc",
328+
"union_extract",
329+
"union_tag",
325330
"upper",
326331
"uuid",
327332
"var",
328333
"var_pop",
329334
"var_samp",
330335
"var_sample",
336+
"version",
331337
"when",
332338
# Window Functions
333339
"window",
@@ -2628,22 +2634,184 @@ def arrow_typeof(arg: Expr) -> Expr:
26282634
return Expr(f.arrow_typeof(arg.expr))
26292635

26302636

2631-
def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
2637+
def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
26322638
"""Casts an expression to a specified data type.
26332639
2640+
The ``data_type`` can be a string, a ``pyarrow.DataType``, or an
2641+
``Expr``. For simple types, :py:meth:`Expr.cast()
2642+
<datafusion.expr.Expr.cast>` is more concise
2643+
(e.g., ``col("a").cast(pa.float64())``). Use ``arrow_cast`` when
2644+
you want to specify the target type as a string using DataFusion's
2645+
type syntax, which can be more readable for complex types like
2646+
``"Timestamp(Nanosecond, None)"``.
2647+
26342648
Examples:
26352649
>>> ctx = dfn.SessionContext()
26362650
>>> df = ctx.from_pydict({"a": [1]})
2637-
>>> data_type = dfn.string_literal("Float64")
26382651
>>> result = df.select(
2639-
... dfn.functions.arrow_cast(dfn.col("a"), data_type).alias("c")
2652+
... dfn.functions.arrow_cast(dfn.col("a"), "Float64").alias("c")
2653+
... )
2654+
>>> result.collect_column("c")[0].as_py()
2655+
1.0
2656+
2657+
>>> import pyarrow as pa
2658+
>>> result = df.select(
2659+
... dfn.functions.arrow_cast(
2660+
... dfn.col("a"), data_type=pa.float64()
2661+
... ).alias("c")
26402662
... )
26412663
>>> result.collect_column("c")[0].as_py()
26422664
1.0
26432665
"""
2666+
if isinstance(data_type, pa.DataType):
2667+
return expr.cast(data_type)
2668+
if isinstance(data_type, str):
2669+
data_type = Expr.string_literal(data_type)
26442670
return Expr(f.arrow_cast(expr.expr, data_type.expr))
26452671

26462672

2673+
def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
2674+
"""Returns the metadata of the input expression.
2675+
2676+
If called with one argument, returns a Map of all metadata key-value pairs.
2677+
If called with two arguments, returns the value for the specified metadata key.
2678+
2679+
Examples:
2680+
>>> import pyarrow as pa
2681+
>>> field = pa.field("val", pa.int64(), metadata={"k": "v"})
2682+
>>> schema = pa.schema([field])
2683+
>>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema)
2684+
>>> ctx = dfn.SessionContext()
2685+
>>> df = ctx.create_dataframe([[batch]])
2686+
>>> result = df.select(
2687+
... dfn.functions.arrow_metadata(dfn.col("val")).alias("meta")
2688+
... )
2689+
>>> ("k", "v") in result.collect_column("meta")[0].as_py()
2690+
True
2691+
2692+
>>> result = df.select(
2693+
... dfn.functions.arrow_metadata(
2694+
... dfn.col("val"), key="k"
2695+
... ).alias("meta_val")
2696+
... )
2697+
>>> result.collect_column("meta_val")[0].as_py()
2698+
'v'
2699+
"""
2700+
if key is None:
2701+
return Expr(f.arrow_metadata(expr.expr))
2702+
if isinstance(key, str):
2703+
key = Expr.string_literal(key)
2704+
return Expr(f.arrow_metadata(expr.expr, key.expr))
2705+
2706+
2707+
def get_field(expr: Expr, name: Expr | str) -> Expr:
2708+
"""Extracts a field from a struct or map by name.
2709+
2710+
When the field name is a static string, the bracket operator
2711+
``expr["field"]`` is a convenient shorthand. Use ``get_field``
2712+
when the field name is a dynamic expression.
2713+
2714+
Examples:
2715+
>>> ctx = dfn.SessionContext()
2716+
>>> df = ctx.from_pydict({"a": [1], "b": [2]})
2717+
>>> df = df.with_column(
2718+
... "s",
2719+
... dfn.functions.named_struct(
2720+
... [("x", dfn.col("a")), ("y", dfn.col("b"))]
2721+
... ),
2722+
... )
2723+
>>> result = df.select(
2724+
... dfn.functions.get_field(dfn.col("s"), "x").alias("x_val")
2725+
... )
2726+
>>> result.collect_column("x_val")[0].as_py()
2727+
1
2728+
2729+
Equivalent using bracket syntax:
2730+
2731+
>>> result = df.select(
2732+
... dfn.col("s")["x"].alias("x_val")
2733+
... )
2734+
>>> result.collect_column("x_val")[0].as_py()
2735+
1
2736+
"""
2737+
if isinstance(name, str):
2738+
name = Expr.string_literal(name)
2739+
return Expr(f.get_field(expr.expr, name.expr))
2740+
2741+
2742+
def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr:
2743+
"""Extracts a value from a union type by field name.
2744+
2745+
Returns the value of the named field if it is the currently selected
2746+
variant, otherwise returns NULL.
2747+
2748+
Examples:
2749+
>>> import pyarrow as pa
2750+
>>> ctx = dfn.SessionContext()
2751+
>>> types = pa.array([0, 1, 0], type=pa.int8())
2752+
>>> offsets = pa.array([0, 0, 1], type=pa.int32())
2753+
>>> arr = pa.UnionArray.from_dense(
2754+
... types, offsets, [pa.array([1, 2]), pa.array(["hi"])],
2755+
... ["int", "str"], [0, 1],
2756+
... )
2757+
>>> batch = pa.RecordBatch.from_arrays([arr], names=["u"])
2758+
>>> df = ctx.create_dataframe([[batch]])
2759+
>>> result = df.select(
2760+
... dfn.functions.union_extract(dfn.col("u"), "int").alias("val")
2761+
... )
2762+
>>> result.collect_column("val").to_pylist()
2763+
[1, None, 2]
2764+
"""
2765+
if isinstance(field_name, str):
2766+
field_name = Expr.string_literal(field_name)
2767+
return Expr(f.union_extract(union_expr.expr, field_name.expr))
2768+
2769+
2770+
def union_tag(union_expr: Expr) -> Expr:
2771+
"""Returns the tag (active field name) of a union type.
2772+
2773+
Examples:
2774+
>>> import pyarrow as pa
2775+
>>> ctx = dfn.SessionContext()
2776+
>>> types = pa.array([0, 1, 0], type=pa.int8())
2777+
>>> offsets = pa.array([0, 0, 1], type=pa.int32())
2778+
>>> arr = pa.UnionArray.from_dense(
2779+
... types, offsets, [pa.array([1, 2]), pa.array(["hi"])],
2780+
... ["int", "str"], [0, 1],
2781+
... )
2782+
>>> batch = pa.RecordBatch.from_arrays([arr], names=["u"])
2783+
>>> df = ctx.create_dataframe([[batch]])
2784+
>>> result = df.select(
2785+
... dfn.functions.union_tag(dfn.col("u")).alias("tag")
2786+
... )
2787+
>>> result.collect_column("tag").to_pylist()
2788+
['int', 'str', 'int']
2789+
"""
2790+
return Expr(f.union_tag(union_expr.expr))
2791+
2792+
2793+
def version() -> Expr:
2794+
"""Returns the DataFusion version string.
2795+
2796+
Examples:
2797+
>>> ctx = dfn.SessionContext()
2798+
>>> df = ctx.empty_table()
2799+
>>> result = df.select(dfn.functions.version().alias("v"))
2800+
>>> "Apache DataFusion" in result.collect_column("v")[0].as_py()
2801+
True
2802+
"""
2803+
return Expr(f.version())
2804+
2805+
2806+
def row(*args: Expr) -> Expr:
2807+
"""Returns a struct with the given arguments.
2808+
2809+
See Also:
2810+
This is an alias for :py:func:`struct`.
2811+
"""
2812+
return struct(*args)
2813+
2814+
26472815
def random() -> Expr:
26482816
"""Returns a random value in the range ``0.0 <= x < 1.0``.
26492817

0 commit comments

Comments
 (0)