-
Notifications
You must be signed in to change notification settings - Fork 150
Add missing SessionContext utility methods #1475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0cfcbb4
f72830f
bb888f4
c7f47f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -63,7 +63,8 @@ | |||||||||
| import polars as pl # type: ignore[import] | ||||||||||
|
|
||||||||||
| from datafusion.catalog import CatalogProvider, Table | ||||||||||
| from datafusion.expr import SortKey | ||||||||||
| from datafusion.common import DFSchema | ||||||||||
| from datafusion.expr import Expr, SortKey | ||||||||||
| from datafusion.plan import ExecutionPlan, LogicalPlan | ||||||||||
| from datafusion.user_defined import ( | ||||||||||
| AggregateUDF, | ||||||||||
|
|
@@ -1141,6 +1142,120 @@ def session_id(self) -> str: | |||||||||
| """Return an id that uniquely identifies this :py:class:`SessionContext`.""" | ||||||||||
| return self.ctx.session_id() | ||||||||||
|
|
||||||||||
| def session_start_time(self) -> str: | ||||||||||
| """Return the session start time as an RFC 3339 formatted string. | ||||||||||
|
|
||||||||||
| Examples: | ||||||||||
| >>> ctx = SessionContext() | ||||||||||
| >>> start_time = ctx.session_start_time() | ||||||||||
| >>> assert "T" in start_time # RFC 3339 contains a 'T' separator | ||||||||||
| """ | ||||||||||
| return self.ctx.session_start_time() | ||||||||||
|
|
||||||||||
| def enable_ident_normalization(self) -> bool: | ||||||||||
| """Return whether identifier normalization (lowercasing) is enabled. | ||||||||||
|
|
||||||||||
| Examples: | ||||||||||
| >>> ctx = SessionContext() | ||||||||||
| >>> assert isinstance(ctx.enable_ident_normalization(), bool) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same thing here:
Suggested change
|
||||||||||
| """ | ||||||||||
| return self.ctx.enable_ident_normalization() | ||||||||||
|
|
||||||||||
timsaucer marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| def parse_sql_expr(self, sql: str, schema: DFSchema) -> Expr: | ||||||||||
| """Parse a SQL expression string into a logical expression. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| sql: SQL expression string. | ||||||||||
| schema: Schema to use for resolving column references. | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| Parsed expression. | ||||||||||
|
|
||||||||||
| Examples: | ||||||||||
| >>> from datafusion.common import DFSchema | ||||||||||
| >>> ctx = SessionContext() | ||||||||||
| >>> schema = DFSchema.empty() | ||||||||||
| >>> expr = ctx.parse_sql_expr("1 + 2", schema) | ||||||||||
| >>> assert "Int64(1) + Int64(2)" in str(expr) | ||||||||||
|
Comment on lines
+1178
to
+1179
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| """ | ||||||||||
| from datafusion.expr import Expr # noqa: PLC0415 | ||||||||||
|
|
||||||||||
| return Expr(self.ctx.parse_sql_expr(sql, schema)) | ||||||||||
|
Comment on lines
+1181
to
+1183
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we could remove the import and the wrapping with |
||||||||||
|
|
||||||||||
| def execute_logical_plan(self, plan: LogicalPlan) -> DataFrame: | ||||||||||
| """Execute a :py:class:`~datafusion.plan.LogicalPlan` and return a DataFrame. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| plan: Logical plan to execute. | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| DataFrame resulting from the execution. | ||||||||||
|
|
||||||||||
| Examples: | ||||||||||
| >>> ctx = SessionContext() | ||||||||||
| >>> df = ctx.from_pydict({"a": [1, 2, 3]}) | ||||||||||
| >>> plan = df.logical_plan() | ||||||||||
| >>> df2 = ctx.execute_logical_plan(plan) | ||||||||||
| >>> df2.collect()[0].column(0) | ||||||||||
| <pyarrow.lib.Int64Array object at ...> | ||||||||||
| [ | ||||||||||
| 1, | ||||||||||
| 2, | ||||||||||
| 3 | ||||||||||
| ] | ||||||||||
| """ | ||||||||||
| return DataFrame(self.ctx.execute_logical_plan(plan._raw_plan)) | ||||||||||
|
|
||||||||||
| def refresh_catalogs(self) -> None: | ||||||||||
| """Refresh catalog metadata. | ||||||||||
|
|
||||||||||
| Examples: | ||||||||||
| >>> ctx = SessionContext() | ||||||||||
| >>> ctx.refresh_catalogs() | ||||||||||
| """ | ||||||||||
| self.ctx.refresh_catalogs() | ||||||||||
|
|
||||||||||
| def remove_optimizer_rule(self, name: str) -> bool: | ||||||||||
| """Remove an optimizer rule by name. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| name: Name of the optimizer rule to remove. | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| True if a rule with the given name was found and removed. | ||||||||||
|
|
||||||||||
| Examples: | ||||||||||
| >>> ctx = SessionContext() | ||||||||||
| >>> ctx.remove_optimizer_rule("nonexistent_rule") | ||||||||||
| False | ||||||||||
| """ | ||||||||||
| return self.ctx.remove_optimizer_rule(name) | ||||||||||
|
|
||||||||||
| def table_provider(self, name: str) -> Table: | ||||||||||
| """Return the :py:class:`~datafusion.catalog.Table` for the given table name. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| name: Name of the table. | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| The table provider. | ||||||||||
|
|
||||||||||
| Raises: | ||||||||||
| KeyError: If the table is not found. | ||||||||||
|
|
||||||||||
| Examples: | ||||||||||
| >>> import pyarrow as pa | ||||||||||
| >>> ctx = SessionContext() | ||||||||||
| >>> batch = pa.RecordBatch.from_pydict({"x": [1, 2]}) | ||||||||||
| >>> ctx.register_record_batches("my_table", [[batch]]) | ||||||||||
| >>> tbl = ctx.table_provider("my_table") | ||||||||||
| >>> tbl.schema | ||||||||||
| x: int64 | ||||||||||
| """ | ||||||||||
| from datafusion.catalog import Table # noqa: PLC0415 | ||||||||||
|
|
||||||||||
| return Table(self.ctx.table_provider(name)) | ||||||||||
|
Comment on lines
+1255
to
+1257
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also here I think we can remove the |
||||||||||
|
|
||||||||||
| def read_json( | ||||||||||
| self, | ||||||||||
| path: str | pathlib.Path, | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -551,6 +551,53 @@ def test_table_not_found(ctx): | |||||||||||
| ctx.table(f"not-found-{uuid4()}") | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_session_start_time(ctx): | ||||||||||||
| st = ctx.session_start_time() | ||||||||||||
| assert isinstance(st, str) | ||||||||||||
| assert "T" in st # RFC 3339 format | ||||||||||||
|
Comment on lines
+556
to
+557
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about this? The conversion should fail if the string is badly formatted.
Suggested change
|
||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_enable_ident_normalization(ctx): | ||||||||||||
| result = ctx.enable_ident_normalization() | ||||||||||||
| assert isinstance(result, bool) | ||||||||||||
|
Comment on lines
+561
to
+562
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's better to change the value and check it.
Suggested change
Unrelated but the original method name is a bit misleading since it does not enable the flag, only returns the value. |
||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_parse_sql_expr(ctx): | ||||||||||||
| from datafusion.common import DFSchema | ||||||||||||
|
|
||||||||||||
| schema = DFSchema.empty() | ||||||||||||
| expr = ctx.parse_sql_expr("1 + 2", schema) | ||||||||||||
| assert "Int64(1) + Int64(2)" in str(expr) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_execute_logical_plan(ctx): | ||||||||||||
| df = ctx.from_pydict({"a": [1, 2, 3]}) | ||||||||||||
| plan = df.logical_plan() | ||||||||||||
| df2 = ctx.execute_logical_plan(plan) | ||||||||||||
| result = df2.collect() | ||||||||||||
| assert result[0].column(0) == pa.array([1, 2, 3]) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_refresh_catalogs(ctx): | ||||||||||||
| ctx.refresh_catalogs() | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_remove_optimizer_rule(ctx): | ||||||||||||
| assert ctx.remove_optimizer_rule("nonexistent_rule") is False | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Testing with a rule that exists as well:
Suggested change
|
||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_table_provider(ctx): | ||||||||||||
| batch = pa.RecordBatch.from_pydict({"x": [10, 20, 30]}) | ||||||||||||
| ctx.register_record_batches("provider_test", [[batch]]) | ||||||||||||
| tbl = ctx.table_provider("provider_test") | ||||||||||||
| assert tbl.schema == pa.schema([("x", pa.int64())]) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_table_provider_not_found(ctx): | ||||||||||||
| with pytest.raises(KeyError): | ||||||||||||
| ctx.table_provider("nonexistent_table") | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_read_json(ctx): | ||||||||||||
| path = pathlib.Path(__file__).parent.resolve() | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assert feels a little odd, what about showing a result?