diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index fefe14b3e..3f07da95b 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -494,6 +494,8 @@ expr_fn!(length, string); expr_fn!(char_length, string); expr_fn!(chr, arg, "Returns the character with the given code."); expr_fn_vec!(coalesce); +expr_fn_vec!(greatest); +expr_fn_vec!(least); expr_fn!( contains, string search_str, @@ -548,6 +550,11 @@ expr_fn!( x y, "Returns x if x is not NULL otherwise returns y." ); +expr_fn!( + nvl2, + x y z, + "Returns y if x is not NULL; otherwise returns z." +); expr_fn!(nullif, arg_1 arg_2); expr_fn!( octet_length, @@ -989,6 +996,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(floor))?; m.add_wrapped(wrap_pyfunction!(from_unixtime))?; m.add_wrapped(wrap_pyfunction!(gcd))?; + m.add_wrapped(wrap_pyfunction!(greatest))?; // m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; @@ -996,6 +1004,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(iszero))?; m.add_wrapped(wrap_pyfunction!(levenshtein))?; m.add_wrapped(wrap_pyfunction!(lcm))?; + m.add_wrapped(wrap_pyfunction!(least))?; m.add_wrapped(wrap_pyfunction!(left))?; m.add_wrapped(wrap_pyfunction!(length))?; m.add_wrapped(wrap_pyfunction!(ln))?; @@ -1013,6 +1022,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(named_struct))?; m.add_wrapped(wrap_pyfunction!(nanvl))?; m.add_wrapped(wrap_pyfunction!(nvl))?; + m.add_wrapped(wrap_pyfunction!(nvl2))?; m.add_wrapped(wrap_pyfunction!(now))?; m.add_wrapped(wrap_pyfunction!(nullif))?; m.add_wrapped(wrap_pyfunction!(octet_length))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 2ef2f0473..f1ea3d256 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -152,6 +152,8 @@ "floor", "from_unixtime", "gcd", + "greatest", + "ifnull", "in_list", "initcap", "isnan", @@ -160,6 +162,7 @@ "last_value", "lcm", "lead", + "least", "left", "length", "levenshtein", @@ -216,6 +219,7 @@ "ntile", "nullif", "nvl", + "nvl2", "octet_length", "order_by", "overlay", @@ -1045,6 +1049,34 @@ def gcd(x: Expr, y: Expr) -> Expr: return Expr(f.gcd(x.expr, y.expr)) +def greatest(*args: Expr) -> Expr: + """Returns the greatest value from a list of expressions. + + Returns NULL if all expressions are NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 3], "b": [2, 1]}) + >>> result = df.select( + ... dfn.functions.greatest(dfn.col("a"), dfn.col("b")).alias("greatest")) + >>> result.collect_column("greatest")[0].as_py() + 2 + >>> result.collect_column("greatest")[1].as_py() + 3 + """ + exprs = [arg.expr for arg in args] + return Expr(f.greatest(*exprs)) + + +def ifnull(x: Expr, y: Expr) -> Expr: + """Returns ``x`` if ``x`` is not NULL. Otherwise returns ``y``. + + See Also: + This is an alias for :py:func:`nvl`. + """ + return nvl(x, y) + + def initcap(string: Expr) -> Expr: """Set the initial letter of each word to capital. @@ -1098,6 +1130,25 @@ def lcm(x: Expr, y: Expr) -> Expr: return Expr(f.lcm(x.expr, y.expr)) +def least(*args: Expr) -> Expr: + """Returns the least value from a list of expressions. + + Returns NULL if all expressions are NULL. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 3], "b": [2, 1]}) + >>> result = df.select( + ... dfn.functions.least(dfn.col("a"), dfn.col("b")).alias("least")) + >>> result.collect_column("least")[0].as_py() + 1 + >>> result.collect_column("least")[1].as_py() + 1 + """ + exprs = [arg.expr for arg in args] + return Expr(f.least(*exprs)) + + def left(string: Expr, n: Expr) -> Expr: """Returns the first ``n`` characters in the ``string``. @@ -1282,6 +1333,24 @@ def nvl(x: Expr, y: Expr) -> Expr: return Expr(f.nvl(x.expr, y.expr)) +def nvl2(x: Expr, y: Expr, z: Expr) -> Expr: + """Returns ``y`` if ``x`` is not NULL. Otherwise returns ``z``. + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [None, 1], "b": [10, 20], "c": [30, 40]}) + >>> result = df.select( + ... dfn.functions.nvl2( + ... dfn.col("a"), dfn.col("b"), dfn.col("c")).alias("nvl2") + ... ) + >>> result.collect_column("nvl2")[0].as_py() + 30 + >>> result.collect_column("nvl2")[1].as_py() + 20 + """ + return Expr(f.nvl2(x.expr, y.expr, z.expr)) + + def octet_length(arg: Expr) -> Expr: """Returns the number of bytes of a string. diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index db141fbe0..74fcbffb4 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1410,62 +1410,253 @@ def test_alias_with_metadata(df): assert df.schema().field("b").metadata == {b"key": b"value"} -def test_coalesce(df): - # Create a DataFrame with null values +@pytest.fixture +def df_with_nulls(): ctx = SessionContext() + # Rows: + # 0: both values present + # 1: a/d/h/k null, b/e/i/l present + # 2: a/d/h/k present, b/e/i/l null + # 3: all null batch = pa.RecordBatch.from_arrays( [ - pa.array(["Hello", None, "!"]), # string column with null - pa.array([4, None, 6]), # integer column with null - pa.array(["hello ", None, " !"]), # string column with null + pa.array([1, None, 3, None], type=pa.int64()), + pa.array([5, 10, None, None], type=pa.int64()), + pa.array([20, 30, 40, None], type=pa.int64()), + pa.array(["apple", None, "cherry", None], type=pa.utf8()), + pa.array(["banana", "date", None, None], type=pa.utf8()), + pa.array(["x", "y", "z", None], type=pa.utf8()), pa.array( [ - datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), None, - datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), - ] - ), # datetime with null - pa.array([False, None, True]), # boolean column with null + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + pa.array( + [ + datetime(2022, 7, 4, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + None, + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + pa.array([True, None, False, None], type=pa.bool_()), + pa.array([False, True, None, None], type=pa.bool_()), ], - names=["a", "b", "c", "d", "e"], - ) - df_with_nulls = ctx.create_dataframe([[batch]]) - - # Test coalesce with different data types - result_df = df_with_nulls.select( - f.coalesce(column("a"), literal("default")).alias("a_coalesced"), - f.coalesce(column("b"), literal(0)).alias("b_coalesced"), - f.coalesce(column("c"), literal("default")).alias("c_coalesced"), - f.coalesce(column("d"), literal(datetime(2000, 1, 1, tzinfo=DEFAULT_TZ))).alias( - "d_coalesced" - ), - f.coalesce(column("e"), literal(value=False)).alias("e_coalesced"), + names=["a", "b", "c", "d", "e", "g", "h", "i", "k", "l"], ) + return ctx.create_dataframe([[batch]]) - result = result_df.collect()[0] - # Verify results - assert result.column(0) == pa.array( - ["Hello", "default", "!"], type=pa.string_view() - ) - assert result.column(1) == pa.array([4, 0, 6], type=pa.int64()) - assert result.column(2) == pa.array( - ["hello ", "default", " !"], type=pa.string_view() - ) - assert result.column(3).to_pylist() == [ - datetime(2022, 12, 31, tzinfo=DEFAULT_TZ), - datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), - datetime(2020, 7, 2, tzinfo=DEFAULT_TZ), - ] - assert result.column(4) == pa.array([False, False, True], type=pa.bool_()) - - # Test multiple arguments - result_df = df_with_nulls.select( - f.coalesce(column("a"), literal(None), literal("fallback")).alias( - "multi_coalesce" - ) - ) - result = result_df.collect()[0] - assert result.column(0) == pa.array( - ["Hello", "fallback", "!"], type=pa.string_view() - ) +@pytest.mark.parametrize( + ("expr", "expected"), + [ + pytest.param( + f.greatest(column("a"), column("b")), + pa.array([5, 10, 3, None], type=pa.int64()), + id="greatest_int", + ), + pytest.param( + f.greatest(column("d"), column("e")), + pa.array(["banana", "date", "cherry", None], type=pa.utf8()), + id="greatest_str", + ), + pytest.param( + f.least(column("a"), column("b")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="least_int", + ), + pytest.param( + f.least(column("d"), column("e")), + pa.array(["apple", "date", "cherry", None], type=pa.utf8()), + id="least_str", + ), + pytest.param( + f.coalesce(column("a"), column("b"), column("c")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="coalesce_int", + ), + pytest.param( + f.coalesce(column("d"), column("e"), column("g")), + pa.array(["apple", "date", "cherry", None], type=pa.utf8()), + id="coalesce_str", + ), + pytest.param( + f.nvl(column("a"), column("c")), + pa.array([1, 30, 3, None], type=pa.int64()), + id="nvl_int", + ), + pytest.param( + f.nvl(column("d"), column("g")), + pa.array(["apple", "y", "cherry", None], type=pa.utf8()), + id="nvl_str", + ), + pytest.param( + f.ifnull(column("a"), column("c")), + pa.array([1, 30, 3, None], type=pa.int64()), + id="ifnull_int", + ), + pytest.param( + f.ifnull(column("d"), column("g")), + pa.array(["apple", "y", "cherry", None], type=pa.utf8()), + id="ifnull_str", + ), + pytest.param( + f.nvl2(column("a"), column("b"), column("c")), + pa.array([5, 30, None, None], type=pa.int64()), + id="nvl2_int", + ), + pytest.param( + f.nvl2(column("d"), column("e"), column("g")), + pa.array(["banana", "y", None, None], type=pa.utf8()), + id="nvl2_str", + ), + pytest.param( + f.nullif(column("a"), column("b")), + pa.array([1, None, 3, None], type=pa.int64()), + id="nullif_int", + ), + pytest.param( + f.nullif(column("d"), column("e")), + pa.array(["apple", None, "cherry", None], type=pa.utf8()), + id="nullif_str", + ), + pytest.param( + f.nullif(column("a"), literal(1)), + pa.array([None, None, 3, None], type=pa.int64()), + id="nullif_equal_values", + ), + pytest.param( + f.greatest(column("a"), column("b"), column("c")), + pa.array([20, 30, 40, None], type=pa.int64()), + id="greatest_variadic", + ), + pytest.param( + f.least(column("a"), column("b"), column("c")), + pa.array([1, 10, 3, None], type=pa.int64()), + id="least_variadic", + ), + pytest.param( + f.greatest(column("a"), literal(2)), + pa.array([2, 2, 3, 2], type=pa.int64()), + id="greatest_literal", + ), + pytest.param( + f.least(column("a"), literal(2)), + pa.array([1, 2, 2, 2], type=pa.int64()), + id="least_literal", + ), + pytest.param( + f.coalesce(column("a"), literal(0)), + pa.array([1, 0, 3, 0], type=pa.int64()), + id="coalesce_literal_int", + ), + pytest.param( + f.coalesce(column("d"), literal("default")), + pa.array(["apple", "default", "cherry", "default"], type=pa.string_view()), + id="coalesce_literal_str", + ), + pytest.param( + f.nvl(column("a"), literal(99)), + pa.array([1, 99, 3, 99], type=pa.int64()), + id="nvl_literal", + ), + pytest.param( + f.ifnull(column("d"), literal("unknown")), + pa.array(["apple", "unknown", "cherry", "unknown"], type=pa.string_view()), + id="ifnull_literal", + ), + pytest.param( + f.nvl2(column("a"), literal(1), literal(0)), + pa.array([1, 0, 1, 0], type=pa.int64()), + id="nvl2_literal", + ), + pytest.param( + f.greatest(column("h"), column("i")), + pa.array( + [ + datetime(2022, 7, 4, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="greatest_datetime", + ), + pytest.param( + f.least(column("h"), column("i")), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="least_datetime", + ), + pytest.param( + f.coalesce(column("h"), column("i")), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2023, 12, 25, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + None, + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="coalesce_datetime", + ), + pytest.param( + f.nvl(column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="nvl_bool", + ), + pytest.param( + f.coalesce(column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="coalesce_bool", + ), + pytest.param( + f.nvl2(column("k"), column("k"), column("l")), + pa.array([True, True, False, None], type=pa.bool_()), + id="nvl2_bool", + ), + pytest.param( + f.coalesce( + column("h"), + literal(datetime(2000, 1, 1, tzinfo=DEFAULT_TZ)), + ), + pa.array( + [ + datetime(2020, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), + datetime(2025, 6, 15, tzinfo=DEFAULT_TZ), + datetime(2000, 1, 1, tzinfo=DEFAULT_TZ), + ], + type=pa.timestamp("us", tz="UTC"), + ), + id="coalesce_literal_datetime", + ), + pytest.param( + f.coalesce(column("k"), literal(value=False)), + pa.array([True, False, False, False], type=pa.bool_()), + id="coalesce_literal_bool", + ), + pytest.param( + f.coalesce(column("a"), literal(None), literal(99)), + pa.array([1, 99, 3, 99], type=pa.int64()), + id="coalesce_skip_null_literal", + ), + ], +) +def test_conditional_functions(df_with_nulls, expr, expected): + result = df_with_nulls.select(expr.alias("result")).collect()[0] + assert result.column(0) == expected