diff --git a/deepnote_toolkit/sql/setup_statement_parser.py b/deepnote_toolkit/sql/setup_statement_parser.py new file mode 100644 index 0000000..cf2056e --- /dev/null +++ b/deepnote_toolkit/sql/setup_statement_parser.py @@ -0,0 +1,251 @@ +"""Parser that strips a strict leading run of session-setup statements +(``USE WAREHOUSE ...``, ``USE ROLE ...``, ``SET ...``, ``ALTER SESSION ...``) +off a rendered SQL query so they can be executed as `setup_statements` on the +same connection as the main query. +""" + +import re +from typing import Optional + +# Allowlist of leading session-setup statement keywords. Each entry is the +# tuple of consecutive keywords (case-insensitive, whitespace between). +_SETUP_PREFIXES: tuple[tuple[str, ...], ...] = ( + ("USE", "WAREHOUSE"), + ("USE", "DATABASE"), + ("USE", "SCHEMA"), + ("USE", "ROLE"), + ("USE", "SECONDARY", "ROLES"), + ("SET",), + ("ALTER", "SESSION"), +) + +# Placeholder pattern per param_style. JinjaSQL emits these for bound values. +_PLACEHOLDER_PATTERNS: dict[str, re.Pattern[str]] = { + "pyformat": re.compile(r"%\([^)]+\)s"), + "format": re.compile(r"%s"), + "named": re.compile(r":[A-Za-z_][A-Za-z0-9_]*"), + "numeric": re.compile(r":\d+"), + "qmark": re.compile(r"\?"), +} + + +class SetupStatementError(ValueError): + """Raised when a leading USE/SET/ALTER SESSION statement contains a + Jinja-bound value, which can't be passed as a SQL parameter to most + drivers (and isn't accepted by Snowflake's USE WAREHOUSE at all). + """ + + +def extract_setup_statements( + rendered_query: str, param_style: Optional[str] = None +) -> tuple[list[str], str]: + """Strip a strict leading run of session-setup statements off the input. + + Returns ``(setup_statements, remaining_query)``. ``setup_statements`` are + the trimmed statement bodies (no trailing ``;``); ``remaining_query`` is + the rest of the input from the first non-setup statement onwards. + + A statement is a setup-statement candidate iff its first non-whitespace, + non-comment tokens match one of the allowlisted prefixes (``USE WAREHOUSE``, + ``USE DATABASE``, ``USE SCHEMA``, ``USE ROLE``, ``USE SECONDARY ROLES``, + ``SET``, ``ALTER SESSION``). Comparison is case-insensitive. + + The leading run ends at the first statement whose prefix is not in the + allowlist; everything from there on is the remaining query. + + If the entire input is setup-only (no main query follows), the input is + returned unchanged with an empty list — callers fall through to today's + pandas-multi-statement behavior rather than silently swallowing the cell. + + If any candidate setup statement contains a bind placeholder for the + given ``param_style`` (outside quoted regions), raises + :class:`SetupStatementError`. The message points the caller at the + explicit ``setup_statements=`` kwarg. + """ + pos = 0 + n = len(rendered_query) + extracted_ranges: list[tuple[int, int]] = [] # (start, end_excl_semicolon) + + while pos < n: + new_pos = _skip_whitespace_and_comments(rendered_query, pos) + if new_pos >= n: + break + stmt_start = new_pos + + match_end = _match_setup_prefix(rendered_query, stmt_start) + if match_end is None: + break + + stmt_end = _find_unquoted_semicolon(rendered_query, match_end) + if stmt_end is None: + # No closing semicolon — would consume the rest of the cell as a + # setup statement. Don't extract; pass through unchanged. + return [], rendered_query + + extracted_ranges.append((stmt_start, stmt_end)) + pos = stmt_end + 1 + + if not extracted_ranges: + return [], rendered_query + + # If nothing of substance follows the last extracted setup statement, + # the whole cell was setup-only — pass through unchanged so the user + # sees the original failure mode rather than a silent no-op. + tail = rendered_query[pos:] + if _skip_whitespace_and_comments(tail, 0) >= len(tail): + return [], rendered_query + + setup_statements = [ + rendered_query[start:end].strip() for (start, end) in extracted_ranges + ] + + placeholder_re = _PLACEHOLDER_PATTERNS.get(param_style or "pyformat") + if placeholder_re is not None: + for stmt in setup_statements: + if _has_match_outside_quotes(stmt, placeholder_re): + raise SetupStatementError( + "Templated values in leading USE/SET/ALTER SESSION " + "statements aren't supported. Either inline the value " + "(e.g. `USE WAREHOUSE prod`) or pass the dynamic " + "statement via the setup_statements= kwarg." + ) + + remaining_query = rendered_query[pos:] + return setup_statements, remaining_query + + +def _skip_whitespace_and_comments(s: str, pos: int) -> int: + n = len(s) + while pos < n: + c = s[pos] + if c.isspace(): + pos += 1 + elif c == "-" and pos + 1 < n and s[pos + 1] == "-": + nl = s.find("\n", pos + 2) + pos = nl + 1 if nl != -1 else n + elif c == "/" and pos + 1 < n and s[pos + 1] == "*": + close = s.find("*/", pos + 2) + pos = close + 2 if close != -1 else n + else: + break + return pos + + +def _match_setup_prefix(s: str, pos: int) -> Optional[int]: + """If ``s[pos:]`` starts (case-insensitive) with one of the allowlist + prefixes followed by a word boundary, return the index right after the + prefix. Otherwise None. Whitespace between consecutive keywords is allowed. + """ + n = len(s) + for prefix in _SETUP_PREFIXES: + cur = pos + ok = True + for i, word in enumerate(prefix): + if i > 0: + cur = _skip_inline_whitespace(s, cur) + wl = len(word) + if cur + wl > n or s[cur : cur + wl].upper() != word: + ok = False + break + cur += wl + if not ok: + continue + # Must end on a word boundary so e.g. "SETUP" doesn't match "SET". + if cur < n and (s[cur].isalnum() or s[cur] == "_"): + continue + return cur + return None + + +def _skip_inline_whitespace(s: str, pos: int) -> int: + n = len(s) + while pos < n and s[pos] in " \t\r\n": + pos += 1 + return pos + + +def _find_unquoted_semicolon(s: str, pos: int) -> Optional[int]: + n = len(s) + while pos < n: + c = s[pos] + if c == ";": + return pos + elif c == "'": + pos = _skip_single_quoted(s, pos) + elif c == '"': + pos = _skip_double_quoted(s, pos) + elif c == "$" and pos + 1 < n and s[pos + 1] == "$": + pos = _skip_dollar_quoted(s, pos) + elif c == "-" and pos + 1 < n and s[pos + 1] == "-": + nl = s.find("\n", pos + 2) + pos = nl + 1 if nl != -1 else n + elif c == "/" and pos + 1 < n and s[pos + 1] == "*": + close = s.find("*/", pos + 2) + pos = close + 2 if close != -1 else n + else: + pos += 1 + return None + + +def _skip_single_quoted(s: str, pos: int) -> int: + """``pos`` is at the opening ``'``. Returns position past the closing one, + treating doubled ``''`` as an escaped quote.""" + n = len(s) + pos += 1 + while pos < n: + if s[pos] == "'": + if pos + 1 < n and s[pos + 1] == "'": + pos += 2 + else: + return pos + 1 + else: + pos += 1 + return n + + +def _skip_double_quoted(s: str, pos: int) -> int: + n = len(s) + pos += 1 + while pos < n: + if s[pos] == '"': + if pos + 1 < n and s[pos + 1] == '"': + pos += 2 + else: + return pos + 1 + else: + pos += 1 + return n + + +def _skip_dollar_quoted(s: str, pos: int) -> int: + """``pos`` is at the first ``$`` of ``$$``. Returns position past the + closing ``$$`` (or EOF if missing).""" + pos += 2 + close = s.find("$$", pos) + return close + 2 if close != -1 else len(s) + + +def _has_match_outside_quotes(s: str, pattern: re.Pattern[str]) -> bool: + """Check whether *pattern* matches anywhere in *s* outside of SQL string + literals, double-quoted identifiers, ``$$``-quoted strings, or comments.""" + pos = 0 + n = len(s) + while pos < n: + c = s[pos] + if c == "'": + pos = _skip_single_quoted(s, pos) + elif c == '"': + pos = _skip_double_quoted(s, pos) + elif c == "$" and pos + 1 < n and s[pos + 1] == "$": + pos = _skip_dollar_quoted(s, pos) + elif c == "-" and pos + 1 < n and s[pos + 1] == "-": + nl = s.find("\n", pos + 2) + pos = nl + 1 if nl != -1 else n + elif c == "/" and pos + 1 < n and s[pos + 1] == "*": + close = s.find("*/", pos + 2) + pos = close + 2 if close != -1 else n + else: + if pattern.match(s, pos): + return True + pos += 1 + return False diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 07b61fe..7b9a8dc 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -34,6 +34,7 @@ from deepnote_toolkit.sql.duckdb_sql import execute_duckdb_sql from deepnote_toolkit.sql.jinjasql_utils import render_jinja_sql_template from deepnote_toolkit.sql.query_preview import DeepnoteQueryPreview +from deepnote_toolkit.sql.setup_statement_parser import extract_setup_statements from deepnote_toolkit.sql.sql_caching import get_sql_cache, upload_sql_cache from deepnote_toolkit.sql.sql_query_chaining import add_limit_clause, unchain_sql_query from deepnote_toolkit.sql.sql_utils import is_single_select_query @@ -102,6 +103,7 @@ def execute_sql_with_connection_json( audit_sql_comment="", sql_cache_mode="cache_disabled", return_variable_type="dataframe", + setup_statements=None, ): """ Executes a SQL query using the given connection JSON (string). @@ -111,6 +113,11 @@ def execute_sql_with_connection_json( :param sql_alchemy_json: String containing JSON with the connection details. Mandatory fields: url, params, param_style :param sql_cache_mode: SQL caching setting for the query. Possible values: "cache_disabled", "always_write", "read_or_write" + :param setup_statements: Optional list of raw SQL statements to run on the + same connection right before *template*. Use for session setup such as + ``USE WAREHOUSE abc`` whose effect must be visible to the main query. + Statements are not Jinja-rendered, parameter-bound, or audit-commented; + they are executed in order via ``connection.exec_driver_sql``. :return: Pandas dataframe with the result """ @@ -204,6 +211,20 @@ class ExecuteSqlError(Exception): if not compiled_query.strip(): return + # Strip a leading run of session-setup statements (USE WAREHOUSE, + # USE ROLE, SET ..., ALTER SESSION ...) off the compiled query and + # run them as setup_statements on the same connection as the main + # query. Explicit setup_statements from the caller are appended. + parsed_setup, compiled_query = extract_setup_statements( + compiled_query, param_style + ) + # query_preview_source mirrors the user-visible main query for cache + # key/preview purposes — keep it aligned with the post-extraction + # remainder. + if parsed_setup: + query_preview_source = compiled_query + final_setup_statements = parsed_setup + list(setup_statements or []) + if ( not is_single_select_query(compiled_query) and return_variable_type == "query_preview" @@ -221,6 +242,7 @@ class ExecuteSqlError(Exception): sql_cache_mode, return_variable_type, query_preview_source, + setup_statements=final_setup_statements, ) @@ -230,6 +252,7 @@ def execute_sql( audit_sql_comment="", sql_cache_mode="cache_disabled", return_variable_type="dataframe", + setup_statements=None, ): """ Wrapper around execute_sql_with_connection_json which reads the connection JSON from @@ -237,6 +260,7 @@ def execute_sql( :param template: Templated SQL :param sql_alchemy_json_env_var: Name of the environment variable containing the connection JSON :param sql_cache_mode: SQL caching setting for the query. Possible values: "cache_disabled", "always_write", "read_or_write" + :param setup_statements: See ``execute_sql_with_connection_json``. :return: Pandas dataframe with the result """ @@ -260,6 +284,7 @@ class ExecuteSqlError(Exception): audit_sql_comment=audit_sql_comment, sql_cache_mode=sql_cache_mode, return_variable_type=return_variable_type, + setup_statements=setup_statements, ) @@ -406,9 +431,14 @@ def _execute_sql_with_caching( sql_cache_mode, return_variable_type, query_preview_source, + setup_statements=None, ): # duckdb SQL is not cached, so we can skip the logic below for duckdb if requires_duckdb: + # DuckDB uses a process-wide singleton connection, so session state set + # by setup_statements naturally persists for the main query. + for stmt in setup_statements or []: + execute_duckdb_sql(stmt, {}) dataframe = execute_duckdb_sql(query, bind_params) # for Chained SQL we return the dataframe with the SQL source attached as DeepnoteQueryPreview object if return_variable_type == "query_preview": @@ -447,6 +477,7 @@ def _execute_sql_with_caching( cache_upload_url, return_variable_type, query_preview_source, # The original query before any transformations such as appending a LIMIT clause + setup_statements=setup_statements, ) @@ -478,6 +509,7 @@ def _query_data_source( cache_upload_url, return_variable_type, query_preview_source, + setup_statements=None, ): sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False) @@ -491,7 +523,9 @@ def _query_data_source( ) try: - dataframe = _execute_sql_on_engine(engine, query, bind_params) + dataframe = _execute_sql_on_engine( + engine, query, bind_params, setup_statements=setup_statements + ) if dataframe is None: return None @@ -609,13 +643,19 @@ def _cancel_cursor(cursor: "DBAPICursor") -> None: pass # Best effort, ignore all errors -def _execute_sql_on_engine(engine, query, bind_params): +def _execute_sql_on_engine(engine, query, bind_params, setup_statements=None): """Run *query* on *engine* and return a DataFrame. Uses pandas.read_sql_query to execute the query with a SQLAlchemy connection. For pandas 2.2+ and SQLAlchemy < 2.0, which requires a raw DB-API connection with a `.cursor()` attribute, we use the underlying connection. + When *setup_statements* is provided, each statement is executed on the + same DBAPI connection right before the main query so any session state + it sets (e.g. Snowflake ``USE WAREHOUSE``) is in effect when the main + query runs. Setup statements are issued via ``connection.exec_driver_sql`` + and any failure aborts the main query. + On exceptions (including KeyboardInterrupt from cell cancellation), all cursors created during execution are cancelled to stop running queries on the server. """ @@ -639,6 +679,16 @@ def _execute_sql_on_engine(engine, query, bind_params): ) with engine.begin() as connection: + # Run setup statements first on the same physical connection so any + # session state they set is visible to the main query below. + # Setup statements are session-control SQL (USE WAREHOUSE, USE ROLE, + # SET, ALTER SESSION) that cannot be parameter-bound — Snowflake and + # most other engines reject placeholders here. The contract on this + # function is that callers pass trusted SQL. + for stmt in setup_statements or []: + # nosemgrep: python.sqlalchemy.security.sqlalchemy-execute-raw-query.sqlalchemy-execute-raw-query + connection.exec_driver_sql(stmt) + # For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection if needs_raw_connection: tracking_connection = CursorTrackingDBAPIConnection(connection.connection) diff --git a/tests/unit/test_setup_statement_parser.py b/tests/unit/test_setup_statement_parser.py new file mode 100644 index 0000000..6d916f6 --- /dev/null +++ b/tests/unit/test_setup_statement_parser.py @@ -0,0 +1,189 @@ +import pytest + +from deepnote_toolkit.sql.setup_statement_parser import ( + SetupStatementError, + extract_setup_statements, +) + + +def test_extracts_single_use_warehouse(): + setup, remaining = extract_setup_statements("USE WAREHOUSE abc; SELECT 1") + assert setup == ["USE WAREHOUSE abc"] + assert remaining.strip() == "SELECT 1" + + +def test_extracts_multiple_setup_statements_in_order(): + setup, remaining = extract_setup_statements( + "USE WAREHOUSE abc; USE ROLE r; SELECT 1" + ) + assert setup == ["USE WAREHOUSE abc", "USE ROLE r"] + assert remaining.strip() == "SELECT 1" + + +def test_recognises_all_allowlist_prefixes(): + cell = ( + "USE WAREHOUSE w; USE DATABASE d; USE SCHEMA s; " + "USE ROLE r; USE SECONDARY ROLES ALL; " + "SET v = 'x'; ALTER SESSION SET TIMEZONE = 'UTC'; " + "SELECT 1" + ) + setup, remaining = extract_setup_statements(cell) + assert setup == [ + "USE WAREHOUSE w", + "USE DATABASE d", + "USE SCHEMA s", + "USE ROLE r", + "USE SECONDARY ROLES ALL", + "SET v = 'x'", + "ALTER SESSION SET TIMEZONE = 'UTC'", + ] + assert remaining.strip() == "SELECT 1" + + +def test_case_insensitive(): + setup, remaining = extract_setup_statements("use Warehouse abc; SeLeCt 1") + assert setup == ["use Warehouse abc"] + assert remaining.strip() == "SeLeCt 1" + + +def test_skips_leading_whitespace_and_newlines(): + setup, remaining = extract_setup_statements("\n\n USE WAREHOUSE abc;\nSELECT 1") + assert setup == ["USE WAREHOUSE abc"] + assert "SELECT 1" in remaining + + +def test_skips_leading_line_comments(): + setup, remaining = extract_setup_statements( + "-- pick the right wh\nUSE WAREHOUSE abc;\nSELECT 1" + ) + assert setup == ["USE WAREHOUSE abc"] + assert "SELECT 1" in remaining + + +def test_skips_leading_block_comments(): + setup, remaining = extract_setup_statements( + "/* setup */ USE WAREHOUSE abc;\nSELECT 1" + ) + assert setup == ["USE WAREHOUSE abc"] + assert "SELECT 1" in remaining + + +def test_skips_comments_between_setup_statements(): + setup, remaining = extract_setup_statements( + "USE WAREHOUSE abc;\n-- next\nUSE ROLE r;\nSELECT 1" + ) + assert setup == ["USE WAREHOUSE abc", "USE ROLE r"] + assert "SELECT 1" in remaining + + +def test_quoted_identifier_with_semicolon(): + setup, remaining = extract_setup_statements('USE WAREHOUSE "my;wh"; SELECT 1') + assert setup == ['USE WAREHOUSE "my;wh"'] + assert remaining.strip() == "SELECT 1" + + +def test_string_literal_with_semicolon(): + setup, remaining = extract_setup_statements("SET v = 'a;b'; SELECT 1") + assert setup == ["SET v = 'a;b'"] + assert remaining.strip() == "SELECT 1" + + +def test_dollar_quoted_string_with_semicolon(): + setup, remaining = extract_setup_statements("SET v = $$a;b$$; SELECT 1") + assert setup == ["SET v = $$a;b$$"] + assert remaining.strip() == "SELECT 1" + + +def test_doubled_inner_quotes_in_identifier(): + setup, remaining = extract_setup_statements('USE WAREHOUSE "a""b;c"; SELECT 1') + assert setup == ['USE WAREHOUSE "a""b;c"'] + assert remaining.strip() == "SELECT 1" + + +def test_no_extraction_when_first_statement_is_not_setup(): + setup, remaining = extract_setup_statements("SELECT * FROM t; USE WAREHOUSE abc") + assert setup == [] + assert remaining == "SELECT * FROM t; USE WAREHOUSE abc" + + +def test_extraction_stops_at_first_non_setup_statement(): + setup, remaining = extract_setup_statements( + "USE WAREHOUSE abc; SELECT 1; USE ROLE r" + ) + assert setup == ["USE WAREHOUSE abc"] + assert remaining.strip() == "SELECT 1; USE ROLE r" + + +def test_setup_only_cell_passes_through_unchanged(): + """If everything is setup with no main query the input is returned + unchanged so the user sees the original failure mode rather than a + silent no-op.""" + cell = "USE WAREHOUSE abc; USE ROLE r;" + setup, remaining = extract_setup_statements(cell) + assert setup == [] + assert remaining == cell + + +def test_setup_only_cell_with_trailing_comments_passes_through(): + cell = "USE WAREHOUSE abc;\n-- trailing comment\n" + setup, remaining = extract_setup_statements(cell) + assert setup == [] + assert remaining == cell + + +def test_no_setup_statements_returns_input_unchanged(): + cell = "SELECT 1" + setup, remaining = extract_setup_statements(cell) + assert setup == [] + assert remaining == cell + + +def test_unterminated_setup_statement_passes_through(): + """No closing ; on the leading USE — would otherwise consume the rest of + the cell as a setup statement; safer to not extract.""" + cell = "USE WAREHOUSE abc" # no semicolon + setup, remaining = extract_setup_statements(cell) + assert setup == [] + assert remaining == cell + + +def test_set_keyword_must_have_word_boundary(): + """`SETUP` must not match `SET`.""" + cell = "SETUP something_else; SELECT 1" + setup, remaining = extract_setup_statements(cell) + assert setup == [] + assert remaining == cell + + +@pytest.mark.parametrize( + "param_style,placeholder", + [ + ("pyformat", "%(p_0)s"), + ("format", "%s"), + ("named", ":p_0"), + ("numeric", ":1"), + ("qmark", "?"), + ], +) +def test_raises_when_setup_statement_contains_placeholder(param_style, placeholder): + """A templated value in a setup statement renders as a bind placeholder. + `connection.exec_driver_sql` doesn't bind; raise so the caller knows to + inline the value or pass it explicitly.""" + cell = f"USE WAREHOUSE {placeholder}; SELECT 1" + with pytest.raises(SetupStatementError, match="setup_statements="): + extract_setup_statements(cell, param_style) + + +def test_placeholder_inside_string_literal_does_not_trigger_error(): + """`%(x)s` inside a single-quoted string is literal text, not a placeholder.""" + setup, remaining = extract_setup_statements("SET v = '%(x)s'; SELECT 1", "pyformat") + assert setup == ["SET v = '%(x)s'"] + assert remaining.strip() == "SELECT 1" + + +def test_default_param_style_is_pyformat(): + """When ``param_style`` is None we still detect pyformat placeholders + because that's JinjaSQL's default.""" + cell = "USE WAREHOUSE %(p_0)s; SELECT 1" + with pytest.raises(SetupStatementError, match="setup_statements="): + extract_setup_statements(cell, None) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index a684077..2dd2431 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -131,6 +131,7 @@ def test_sql_executed_with_audit_comment_but_hash_calculated_without_it( mock.ANY, mock.ANY, mock.ANY, + setup_statements=[], ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -155,6 +156,7 @@ def test_return_variable_type_parameter(self, mocked_query_data_source): mock.ANY, "dataframe", mock.ANY, + setup_statements=[], ) # Test with explicit return_variable_type='query_preview' @@ -171,6 +173,7 @@ def test_return_variable_type_parameter(self, mocked_query_data_source): mock.ANY, "query_preview", mock.ANY, + setup_statements=[], ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -199,6 +202,7 @@ def test_query_preview_preserves_trailing_inline_comment( mock.ANY, "query_preview", mock.ANY, + setup_statements=[], ) @mock.patch("deepnote_toolkit.sql.sql_caching._generate_cache_key") @@ -234,6 +238,7 @@ def test_sql_executed_with_audit_comment_with_semicolon( mock.ANY, "dataframe", mock.ANY, + setup_statements=[], ) @mock.patch("deepnote_toolkit.sql.sql_execution._query_data_source") @@ -331,6 +336,189 @@ def test_execute_sql_with_connection_json_with_snowflake_encrypted_private_key( ) +class TestSetupStatementsAgainstRealSQLite(TestCase): + """End-to-end behavior of setup_statements against a real SQLite engine. + + SQLite ``TEMP TABLE``s are connection-scoped: they are only visible from + the same physical DBAPI connection that created them. So a successful + SELECT here proves the setup statement and the main query ran on the + same connection, not just that we called exec_driver_sql in the right + order. No mocks of pandas or SQLAlchemy. + """ + + def test_explicit_setup_statements_share_connection_with_main_query(self): + sql_alchemy_json = json.dumps( + { + "url": "sqlite:///:memory:", + "params": {}, + "param_style": "qmark", + } + ) + + result = execute_sql_with_connection_json( + "SELECT x FROM t ORDER BY x", + sql_alchemy_json, + setup_statements=[ + "CREATE TEMP TABLE t (x INTEGER)", + "INSERT INTO t VALUES (42)", + "INSERT INTO t VALUES (43)", + ], + ) + + assert result is not None + self.assertEqual(list(result["x"]), [42, 43]) + + def test_failing_setup_statement_aborts_main_query(self): + sql_alchemy_json = json.dumps( + { + "url": "sqlite:///:memory:", + "params": {}, + "param_style": "qmark", + } + ) + + with self.assertRaises(Exception): + execute_sql_with_connection_json( + "SELECT 1", + sql_alchemy_json, + setup_statements=["NOT VALID SQL AT ALL"], + ) + + def test_no_setup_statements_still_works(self): + sql_alchemy_json = json.dumps( + { + "url": "sqlite:///:memory:", + "params": {}, + "param_style": "qmark", + } + ) + + result = execute_sql_with_connection_json("SELECT 1 AS one", sql_alchemy_json) + + assert result is not None + self.assertEqual(list(result["one"]), [1]) + + +class TestSetupStatementParserIntegration(TestCase): + """The parser strips a leading USE/SET/ALTER SESSION run off the cell SQL + and feeds it as setup_statements down to _execute_sql_on_engine. We can't + run USE WAREHOUSE against SQLite, so we observe the wired-through values + via a mock of the engine-level executor — the *behavior* (setup runs on + the same connection as main) is proven by the SQLite tests above.""" + + @mock.patch("deepnote_toolkit.sql.sql_execution._execute_sql_on_engine") + @mock.patch("sqlalchemy.engine.create_engine") + def test_leading_use_warehouse_extracted_from_cell( + self, _mocked_create_engine, mocked_execute_on_engine + ): + mocked_execute_on_engine.return_value = pd.DataFrame({"x": [1]}) + + sql_alchemy_json = json.dumps( + { + "url": "snowflake://u@a/?warehouse=&role=", + "params": {}, + "param_style": "pyformat", + "integration_id": "int_1", + } + ) + + execute_sql_with_connection_json( + "USE WAREHOUSE abc; USE ROLE r; SELECT 1", + sql_alchemy_json, + ) + + _, kwargs = mocked_execute_on_engine.call_args + # The main query is what's left after stripping the prefix; the + # setup statements were passed down for execution on the same + # connection. + passed_query = mocked_execute_on_engine.call_args.args[1] + self.assertEqual(passed_query.strip(), "SELECT 1") + self.assertEqual( + kwargs["setup_statements"], ["USE WAREHOUSE abc", "USE ROLE r"] + ) + + @mock.patch("deepnote_toolkit.sql.sql_execution._execute_sql_on_engine") + @mock.patch("sqlalchemy.engine.create_engine") + def test_explicit_setup_statements_appended_after_parsed( + self, _mocked_create_engine, mocked_execute_on_engine + ): + mocked_execute_on_engine.return_value = pd.DataFrame({"x": [1]}) + + sql_alchemy_json = json.dumps( + { + "url": "snowflake://u@a/?warehouse=&role=", + "params": {}, + "param_style": "pyformat", + "integration_id": "int_1", + } + ) + + execute_sql_with_connection_json( + "USE WAREHOUSE abc; SELECT 1", + sql_alchemy_json, + setup_statements=["ALTER SESSION SET TIMEZONE = 'UTC'"], + ) + + _, kwargs = mocked_execute_on_engine.call_args + self.assertEqual( + kwargs["setup_statements"], + ["USE WAREHOUSE abc", "ALTER SESSION SET TIMEZONE = 'UTC'"], + ) + + @mock.patch("deepnote_toolkit.sql.sql_execution._execute_sql_on_engine") + @mock.patch("sqlalchemy.engine.create_engine") + def test_no_setup_prefix_passes_cell_through_unchanged( + self, _mocked_create_engine, mocked_execute_on_engine + ): + mocked_execute_on_engine.return_value = pd.DataFrame({"x": [1]}) + + sql_alchemy_json = json.dumps( + { + "url": "snowflake://u@a/?warehouse=&role=", + "params": {}, + "param_style": "pyformat", + "integration_id": "int_1", + } + ) + + execute_sql_with_connection_json("SELECT 1", sql_alchemy_json) + + _, kwargs = mocked_execute_on_engine.call_args + passed_query = mocked_execute_on_engine.call_args.args[1] + self.assertEqual(passed_query.strip(), "SELECT 1") + self.assertEqual(kwargs["setup_statements"], []) + + def test_templated_value_in_use_warehouse_raises_clear_error(self): + """`USE WAREHOUSE {{ env }}` renders to a placeholder the driver can't + bind for USE WAREHOUSE. Surface this clearly instead of a confusing + driver error.""" + # Use a literal pyformat placeholder in the template so we don't + # depend on a Jinja variable being resolvable. + import __main__ + + __main__.env = "abc" + try: + sql_alchemy_json = json.dumps( + { + "url": "snowflake://u@a/?warehouse=&role=", + "params": {}, + "param_style": "pyformat", + "integration_id": "int_1", + } + ) + from deepnote_toolkit.sql.setup_statement_parser import ( + SetupStatementError, + ) + + with self.assertRaises(SetupStatementError) as ctx: + execute_sql_with_connection_json( + "USE WAREHOUSE {{ env }}; SELECT 1", sql_alchemy_json + ) + self.assertIn("setup_statements=", str(ctx.exception)) + finally: + del __main__.env + + class TestTrinoParamStyleAutoDetection(TestCase): """Tests for auto-detection of param_style for Trino connections""" diff --git a/tests/unit/test_sql_execution_internal.py b/tests/unit/test_sql_execution_internal.py index ad4ec00..8b8ae9f 100644 --- a/tests/unit/test_sql_execution_internal.py +++ b/tests/unit/test_sql_execution_internal.py @@ -377,6 +377,25 @@ def test_create_sql_ssh_uri_no_ssh(): assert url is None +def test_execute_sql_on_engine_aborts_main_query_when_setup_fails(): + """A failing setup statement must propagate and prevent the main query from running.""" + mock_cursor = mock.MagicMock() + mock_engine = _setup_mock_engine_with_cursor(mock_cursor) + sa_connection = mock_engine.begin.return_value.__enter__.return_value + sa_connection.exec_driver_sql = mock.Mock(side_effect=RuntimeError("bad warehouse")) + + with mock.patch("pandas.read_sql_query") as mock_read: + with pytest.raises(RuntimeError, match="bad warehouse"): + se._execute_sql_on_engine( + mock_engine, + "SELECT 1", + {}, + setup_statements=["USE WAREHOUSE missing"], + ) + + mock_read.assert_not_called() + + def test_create_sql_ssh_uri_missing_key(monkeypatch): def fake_get_env(name, default=None): if name == "PRIVATE_SSH_KEY_BLOB":