diff --git a/superset/commands/sql_lab/estimate.py b/superset/commands/sql_lab/estimate.py index 36bae9de4129..f3fbc8bdaf41 100644 --- a/superset/commands/sql_lab/estimate.py +++ b/superset/commands/sql_lab/estimate.py @@ -102,14 +102,17 @@ def _apply_sql_security(self, sql: str) -> str: db_engine_spec.engine, set(), ) - if disallowed_tables and parsed_script.check_tables_present(disallowed_tables): - found_tables = set() - for statement in parsed_script.statements: - present = {table.table.lower() for table in statement.tables} - for table in disallowed_tables: - if table.lower() in present: - found_tables.add(table) - raise SupersetDisallowedSQLTableException(found_tables or disallowed_tables) + if disallowed_tables: + # Honors schema-qualified denylist entries (e.g. + # ``information_schema.tables``) and reports only the tables + # actually referenced by the query. Pass the selected schema so an + # unqualified reference that resolves to it at runtime (via the + # connection ``search_path``) matches too. + found_tables = parsed_script.get_disallowed_tables( + disallowed_tables, self._schema + ) + if found_tables: + raise SupersetDisallowedSQLTableException(found_tables) if parsed_script.has_mutation() and not self._database.allow_dml: raise SupersetDMLNotAllowedException() diff --git a/superset/config.py b/superset/config.py index 5380df762de2..9bd7582def9b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1830,6 +1830,20 @@ def engine_context_manager( # pylint: disable=unused-argument "pg_read_file", "pg_ls_dir", "pg_read_binary_file", + # PostgreSQL large-object functions: writers can plant arbitrary + # bytes on the server filesystem (lo_export, lo_from_bytea, lowrite, + # lo_put, lo_create, lo_import), readers can pull bytes back out + # (lo_get, loread), and lo_unlink deletes large objects outright. + # Defense-in-depth on top of is_mutating()'s function-name check. 
+ "lo_from_bytea", + "lo_export", + "lo_import", + "lo_put", + "lo_create", + "lowrite", + "lo_get", + "loread", + "lo_unlink", # XML functions that can execute SQL "database_to_xml", "database_to_xmlschema", @@ -1920,6 +1934,30 @@ def engine_context_manager( # pylint: disable=unused-argument "pg_stat_replication", "pg_stat_wal_receiver", "pg_user", + # The SQL-standard `information_schema` views expose table / + # column / privilege / view-definition metadata across the entire + # database role the connection user can see. Entries are + # schema-qualified so `check_tables_present` only matches when the + # reference resolves to `information_schema.` -- either written + # explicitly or as an unqualified name under an `information_schema` + # search_path -- not any user table that happens to share a name. + "information_schema.tables", + "information_schema.columns", + "information_schema.schemata", + "information_schema.views", + "information_schema.routines", + "information_schema.role_table_grants", + "information_schema.role_column_grants", + "information_schema.role_routine_grants", + "information_schema.table_privileges", + "information_schema.column_privileges", + "information_schema.usage_privileges", + "information_schema.key_column_usage", + "information_schema.table_constraints", + "information_schema.referential_constraints", + "information_schema.view_table_usage", + "information_schema.applicable_roles", + "information_schema.enabled_roles", }, "mysql": { "mysql.user", diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 0f6963087240..89feb4ce0368 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1221,24 +1221,19 @@ def _process_sql_expression( # pylint: disable=too-many-arguments disallowed_functions ): raise SupersetDisallowedSQLFunctionException(disallowed_functions) - if disallowed_tables and parsed.check_tables_present(disallowed_tables): + if disallowed_tables: # Report only the tables actually 
found in the expression, # mirroring the canonical execution-time gate in # `superset.sql_lab._validate_query` so the user-facing - # error doesn't echo the operator's full denylist. - present_tables = { - table.table.lower() - for statement in parsed.statements - for table in statement.tables - } - found_tables = { - table - for table in disallowed_tables - if table.lower() in present_tables - } - raise SupersetDisallowedSQLTableException( - found_tables or disallowed_tables + # error doesn't echo the operator's full denylist. Honors + # schema-qualified denylist entries (e.g. + # ``information_schema.tables``) and resolves unqualified + # references against the query schema. + found_tables = parsed.get_disallowed_tables( + disallowed_tables, schema ) + if found_tables: + raise SupersetDisallowedSQLTableException(found_tables) return expression def _process_select_expression( @@ -1465,19 +1460,17 @@ def _raise_for_disallowed_sql(self, sql: str) -> None: disallowed_functions ): raise SupersetDisallowedSQLFunctionException(disallowed_functions) - if disallowed_tables and parsed_script.check_tables_present(disallowed_tables): + if disallowed_tables: # Report only the tables actually found in the query, mirroring the # canonical execution-time gate so the user-facing error doesn't - # echo the operator's full denylist. - present_tables = { - table.table.lower() - for statement in parsed_script.statements - for table in statement.tables - } - found_tables = { - table for table in disallowed_tables if table.lower() in present_tables - } - raise SupersetDisallowedSQLTableException(found_tables or disallowed_tables) + # echo the operator's full denylist. Honors schema-qualified + # denylist entries (e.g. ``information_schema.tables``) and resolves + # unqualified references against the query schema. 
+ found_tables = parsed_script.get_disallowed_tables( + disallowed_tables, self.schema + ) + if found_tables: + raise SupersetDisallowedSQLTableException(found_tables) def query(self, query_obj: QueryObjectDict) -> QueryResult: """ diff --git a/superset/sql/execution/executor.py b/superset/sql/execution/executor.py index 13e12fc1a479..339b1f3583aa 100644 --- a/superset/sql/execution/executor.py +++ b/superset/sql/execution/executor.py @@ -231,7 +231,7 @@ def execute( ) # 2. Security checks on transformed script - self._check_security(transformed_script) + self._check_security(transformed_script, schema) # 3. Get mutation status and format SQL has_mutation = transformed_script.has_mutation() @@ -355,7 +355,7 @@ def execute_async( ) # 2. Security checks on transformed script - self._check_security(transformed_script) + self._check_security(transformed_script, schema) # 3. Get mutation status and format SQL has_mutation = transformed_script.has_mutation() @@ -449,11 +449,12 @@ def _prepare_sql( return original_script, transformed_script, catalog, schema - def _check_security(self, script: SQLScript) -> None: + def _check_security(self, script: SQLScript, schema: str | None = None) -> None: """ Perform security checks on prepared SQL script. 
:param script: Prepared SQLScript + :param schema: Effective schema unqualified references resolve to :raises SupersetSecurityException: If security checks fail """ # Check disallowed functions @@ -469,7 +470,7 @@ def _check_security(self, script: SQLScript) -> None: ) # Check disallowed tables - if disallowed_tables := self._check_disallowed_tables(script): + if disallowed_tables := self._check_disallowed_tables(script, schema): raise SupersetSecurityException( SupersetError( message=f"Disallowed SQL tables: {', '.join(disallowed_tables)}", @@ -702,11 +703,14 @@ def _check_disallowed_functions(self, script: SQLScript) -> set[str] | None: return found if found else None - def _check_disallowed_tables(self, script: SQLScript) -> set[str] | None: + def _check_disallowed_tables( + self, script: SQLScript, schema: str | None = None + ) -> set[str] | None: """ Check for disallowed SQL tables/views. :param script: Parsed SQL script + :param schema: Effective schema unqualified references resolve to :returns: Set of disallowed tables found, or None if none found """ disallowed_config = app.config.get("DISALLOWED_SQL_TABLES", {}) @@ -717,15 +721,11 @@ def _check_disallowed_tables(self, script: SQLScript) -> set[str] | None: if not engine_disallowed: return None - # Single-pass AST-based table detection - found: set[str] = set() - for statement in script.statements: - present = {table.table.lower() for table in statement.tables} - for table in engine_disallowed: - if table.lower() in present: - found.add(table) - - return found or None + # Honors schema-qualified denylist entries (e.g. + # ``information_schema.tables``) as well as bare names. The effective + # schema lets an unqualified reference that resolves to it at runtime + # (via the connection ``search_path``) match too. 
+ return script.get_disallowed_tables(engine_disallowed, schema) or None def _apply_rls_to_script( self, script: SQLScript, catalog: str | None, schema: str | None diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 6a94a414d9da..67ff877eb79b 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -492,15 +492,32 @@ def check_functions_present(self, functions: set[str]) -> bool: """ raise NotImplementedError() - def check_tables_present(self, tables: set[str]) -> bool: + def check_tables_present( + self, tables: set[str], default_schema: str | None = None + ) -> bool: """ Check if any of the given tables are present in the statement. :param tables: Set of table names to check for (case-insensitive) + :param default_schema: Schema unqualified references resolve to at + runtime (e.g. the session ``search_path`` / selected schema) :return: True if any of the tables are present """ raise NotImplementedError() + def get_disallowed_tables( + self, tables: set[str], default_schema: str | None = None + ) -> set[str]: + """ + Return the subset of ``tables`` referenced by this statement. + + :param tables: Set of table names to check for (case-insensitive) + :param default_schema: Schema unqualified references resolve to at + runtime (e.g. the session ``search_path`` / selected schema) + :return: The matched entries, in their original denylist form + """ + raise NotImplementedError() + def get_limit_value(self) -> int | None: """ Get the limit value of the statement. @@ -592,6 +609,76 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): This class is used for all engines with dialects that can be parsed using sqlglot. """ + # Function names that mutate server-side state but appear in the AST as + # plain function calls inside a non-mutating wrapper. Used by + # ``is_mutating()`` to classify e.g. PostgreSQL large-object writers. + # Names are uppercased for comparison. 
+ _MUTATING_FUNCTION_NAMES: frozenset[str] = frozenset( + { + "LO_FROM_BYTEA", + "LO_EXPORT", + "LO_IMPORT", + "LO_PUT", + "LO_CREATE", + "LOWRITE", + "LO_UNLINK", + # PostgreSQL sequence mutators. `SELECT setval('seq', N)` looks + # like a read but changes sequence state for every subsequent + # `nextval` caller. + "SETVAL", + } + ) + + # PostgreSQL constructs that sqlglot represents as an opaque ``exp.Command`` + # (no structured AST). Each can mutate server state or wrap a DML body that + # would otherwise be detected by node-type matching. Used by + # ``is_mutating()``. + _POSTGRES_MUTATING_COMMAND_NAMES: frozenset[str] = frozenset( + { + "DO", # PL/pgSQL anonymous block + "PREPARE", # PREPARE u AS UPDATE ... ; EXECUTE u + "EXECUTE", # body is the prepared DML + "CALL", # procedure body may mutate + "COPY", # server-side file ingest into a table + "GRANT", + "REVOKE", + "SET", # SET ROLE / SET SESSION AUTHORIZATION change effective user + "RESET", # RESET ROLE / RESET ALL reverts SET; same class as SET + "REFRESH", # REFRESH MATERIALIZED VIEW + "REINDEX", + "VACUUM", + # DDL head-tokens that sqlglot falls back to exp.Command for + # whenever the body uses syntax it does not model + # (CREATE EXTENSION/FUNCTION...LANGUAGE C/PUBLICATION/etc., + # ALTER ROLE/SYSTEM/..., DROP EXTENSION/RULE/...). Well-formed + # CREATE TABLE/ALTER TABLE/DROP TABLE are already caught by the + # node-type tuple; these entries close the fallback path. + "CREATE", + "ALTER", + "DROP", + "LOAD", # LOAD '/path/lib.so' dlopens a shared library on the PG host + # SHOW reads server configuration (version, hba_file, ssl state, + # search_path, etc.). It does not mutate, but in a read-only + # context (`allow_dml=False`) it is information-disclosure + # equivalent to the read-side entries in DISALLOWED_SQL_FUNCTIONS + # which are also blocked. Treat as gated alongside the writers. + "SHOW", + } + ) + + # Dialects where `SELECT ... 
INTO target` is CTAS (creates a table, and so + # mutates schema). Elsewhere the same syntax assigns into a variable and is + # a read: Oracle PL/SQL `SELECT ... INTO v` and MySQL `SELECT ... INTO @v` + # parse into an identical `exp.Select` with an `into` arg, so the dialect is + # the only signal that distinguishes the mutating form from the read form. + _SELECT_INTO_CTAS_DIALECTS: frozenset[Dialects] = frozenset( + { + Dialects.POSTGRES, + Dialects.REDSHIFT, + Dialects.TSQL, + } + ) + def __init__( self, statement: str | None = None, @@ -725,10 +812,40 @@ def is_mutating(self) -> bool: exp.Drop, exp.TruncateTable, exp.Alter, + # sqlglot has structured nodes for these DML/DCL forms in + # PostgreSQL and other dialects; without them an opaque exp.Command + # check would still miss the structured-parse path. + exp.Copy, # COPY FROM/TO (server-side file ingest) + exp.Grant, + exp.Revoke, + # COMMENT ON TABLE/COLUMN/etc. writes to system catalog pg_description. + exp.Comment, ) - for node_type in mutating_nodes: - if self._parsed.find(node_type): + if self._parsed.find(*mutating_nodes): + return True + + # `SELECT ... INTO new_table FROM ...` parses as `exp.Select` with an + # `into` arg (Postgres-style CTAS variant). It creates a new table and + # therefore mutates schema. Only treat it as mutating for dialects where + # the syntax is CTAS; elsewhere it assigns into a variable (a read). + if ( + self._dialect in self._SELECT_INTO_CTAS_DIALECTS + and isinstance(self._parsed, exp.Select) + and self._parsed.args.get("into") + ): + return True + + # Function calls that mutate server-side state without an enclosing + # mutating AST node. Notable example: PostgreSQL large-object writers + # (`lo_export` writes to the server filesystem, `lo_from_bytea`/ + # `lo_create`/`lo_put`/`lo_import`/`lowrite` mutate the pg_largeobject + # catalog). These appear as plain function calls inside an `exp.Select` + # and would otherwise pass the read-only gate. 
Every name in + # _MUTATING_FUNCTION_NAMES is dialect-specific and parses as an + # `exp.Anonymous`, whose `.name` is the bare function identifier. + for function in self._parsed.find_all(exp.Func): + if function.name.upper() in self._MUTATING_FUNCTION_NAMES: return True # depending on the dialect (Oracle, MS SQL) the `ALTER` is parsed as a @@ -736,14 +853,17 @@ def is_mutating(self) -> bool: if isinstance(self._parsed, exp.Command) and self._parsed.name == "ALTER": return True # pragma: no cover + # PostgreSQL constructs that sqlglot represents as an opaque + # `exp.Command` rather than a structured AST. Each of these can mutate + # state or wrap a DML body that would otherwise be detected. The + # `.name` attribute on `exp.Command` preserves the source-case of the + # head keyword (so `create extension ...` would yield `'create'`), + # which means the set lookup must be case-insensitive. if ( self._dialect == Dialects.POSTGRES and isinstance(self._parsed, exp.Command) - and self._parsed.name == "DO" + and self._parsed.name.upper() in self._POSTGRES_MUTATING_COMMAND_NAMES ): - # anonymous blocks can be written in many different languages (the default - # is PL/pgSQL), so parsing them it out of scope of this class; we just - # assume the anonymous block is mutating return True # Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see @@ -864,17 +984,76 @@ def check_functions_present(self, functions: set[str]) -> bool: else: present.add(function.name.upper()) + # MySQL `@@` syntax (also Oracle/SQL-Server `@@name`) parses as + # `exp.SessionParameter`, which is *not* a subclass of `exp.Func`, so + # the walk above misses it. Include those names so denylist entries + # like `version` or `hostname` match `SELECT @@version`. 
def get_disallowed_tables(
    self, tables: set[str], default_schema: str | None = None
) -> set[str]:
    """
    Return the denylist entries from ``tables`` that this statement references.

    Comparison is case-insensitive. An entry without a dot matches any
    referenced table of that name, whatever its schema. An entry containing
    a dot is compared against ``schema.table`` of each reference, so it only
    matches when the schema agrees as well.

    A reference that carries no schema of its own is treated as living in
    ``default_schema`` when one is given; without a default such a reference
    can only be matched by undotted entries.

    :param tables: Denylist entries to look for (bare or ``schema.table``)
    :param default_schema: Schema assumed for references that have none
    :return: The matching entries, spelled exactly as they appear in ``tables``
    """
    implicit_schema = default_schema.lower() if default_schema else None

    # Lower-cased names of everything the statement touches, kept in two
    # pools: plain table names and ``schema.table`` pairs.
    names_seen: set[str] = set()
    qualified_seen: set[str] = set()
    for reference in self.tables:
        name = reference.table.lower()
        names_seen.add(name)
        if reference.schema:
            owner = reference.schema.lower()
        else:
            owner = implicit_schema
        # No explicit schema and no default: nothing qualified to record.
        if owner:
            qualified_seen.add(f"{owner}.{name}")

    def is_referenced(entry: str) -> bool:
        # Dotted entries are only ever compared against the qualified pool.
        key = entry.lower()
        pool = qualified_seen if "." in key else names_seen
        return key in pool

    return {entry for entry in tables if is_referenced(entry)}
def get_disallowed_tables(
    self, tables: set[str], default_schema: str | None = None
) -> set[str]:
    """
    Collect the denylist entries referenced by any statement in the script.

    Each statement is asked for its own matches and the results are merged
    into one set; a script with no statements yields an empty set.

    :param tables: Denylist entries to look for (case-insensitive)
    :param default_schema: Schema passed through to every statement for
        resolving references that have no schema of their own
    :return: The union of the per-statement matches
    """
    per_statement = (
        stmt.get_disallowed_tables(tables, default_schema)
        for stmt in self.statements
    )
    return set().union(*per_statement)
+ found_tables = parsed_script.get_disallowed_tables( + disallowed_tables, query.schema + ) + if found_tables: + raise SupersetDisallowedSQLTableException(found_tables) if parsed_script.has_mutation() and not database.allow_dml: raise SupersetDMLNotAllowedException() diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 7160ea0c3d76..178df3a68d7e 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -26,6 +26,7 @@ from superset.exceptions import QueryClauseValidationException, SupersetParseError from superset.jinja_context import JinjaTemplateProcessor from superset.sql.parse import ( + BaseSQLStatement, CTASMethod, extract_tables_from_statement, JinjaSQLResult, @@ -1143,7 +1144,15 @@ def test_split_kql(kql: str, expected: list[str]) -> None: ("postgresql", "DROP TABLE foo", True), ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False), ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True), - ("postgresql", "SHOW search_path", False), + # SHOW reads server configuration (version, hba_file, ssl, + # search_path, etc.). It is information-disclosure equivalent to + # the read-side entries already in DISALLOWED_SQL_FUNCTIONS + # (pg_read_file, version, current_setting) and is gated alongside + # the writers so it is rejected when allow_dml=False. + ("postgresql", "SHOW search_path", True), + # SET search_path parses as exp.Set (a structured node), not + # exp.Command, so the SET-in-mutating-commands rule does NOT + # catch it. Pure GUC reads/writes stay non-mutating. ("postgresql", "SET search_path TO public", False), ( "postgres", @@ -1388,6 +1397,9 @@ def test_custom_dialect(app: None) -> None: ("SELECT 1", False), ("with source as ( select 1 as one ) select * from source", False), ("ALTER TABLE foo ADD COLUMN bar INT", True), + # COMMENT ON parses as a typed exp.Comment node across dialects; it + # writes to the catalog (pg_description on Postgres) so it is gated. 
+ ("COMMENT ON TABLE t IS 'note'", True), ], ) def test_is_mutating(sql: str, engine: str, expected: bool) -> None: @@ -1434,6 +1446,192 @@ def test_is_mutating_anonymous_block(sql: str, expected: bool) -> None: assert SQLStatement(sql, "postgresql").is_mutating() == expected +@pytest.mark.parametrize( + "sql, expected", + [ + # PostgreSQL large-object writers: each mutates server state. The bare + # SELECT wrapper is irrelevant because the function call itself is the + # side effect. + ("SELECT lo_from_bytea(0, decode('deadbeef', 'hex'))", True), + ("SELECT lo_export(12345, '/tmp/payload.bin')", True), + ("SELECT lo_import('/etc/passwd')", True), + ("SELECT lo_put(12345, 0, decode('00', 'hex'))", True), + ("SELECT lo_create(0)", True), + ("SELECT lowrite(12345, decode('00', 'hex'))", True), + # lo_unlink deletes a large object outright. + ("SELECT lo_unlink(12345)", True), + # PostgreSQL sequence mutator. setval() looks like a read but + # advances sequence state for every subsequent nextval caller. + ("SELECT setval('public.my_seq', 1000)", True), + ("SELECT SETVAL('public.my_seq', 1)", True), + # Read-side large-object functions are intentionally NOT classified + # as mutating here. They are still blocked via the function denylist + # (see DISALLOWED_SQL_FUNCTIONS) but they do not write state. + ("SELECT lo_get(12345)", False), + ("SELECT loread(12345, 1024)", False), + # Case-insensitive matching: the AST stores the raw casing for + # anonymous functions, the check uppercases both sides. + ("SELECT LO_EXPORT(12345, '/tmp/x')", True), + # `SELECT INTO new_table FROM existing` creates a new relation; treat + # as mutating even though sqlglot parses it as exp.Select. + ("SELECT * INTO new_table FROM existing_table", True), + ("SELECT col INTO TEMP new_table FROM existing_table", True), + # Plain SELECT must remain non-mutating. 
+ ("SELECT 1", False), + ("SELECT * FROM users WHERE id = 1", False), + ], +) +def test_is_mutating_postgres_function_and_select_into( + sql: str, expected: bool +) -> None: + """ + `is_mutating` must catch mutating function calls (PostgreSQL large-object + writers) and `SELECT ... INTO new_table` even though the wrapping AST + node is a plain `exp.Select`. + """ + assert SQLStatement(sql, "postgresql").is_mutating() == expected + + +@pytest.mark.parametrize( + "engine, sql", + [ + # `SELECT ... INTO new_table` is CTAS only in Postgres/Redshift/T-SQL. + # In Oracle PL/SQL and MySQL the same syntax assigns into a variable + # and is a read, so it must NOT be classified as mutating. + ("oracle", "SELECT col INTO v FROM existing_table"), + ("mysql", "SELECT col INTO @v FROM existing_table"), + ], +) +def test_is_mutating_select_into_variable_is_read(engine: str, sql: str) -> None: + """ + `SELECT ... INTO target` is only CTAS (mutating) for dialects where the + syntax creates a table. On Oracle/MySQL it assigns into a variable and is + a read, so `is_mutating` must return False there. + """ + assert SQLStatement(sql, engine).is_mutating() is False + + +@pytest.mark.parametrize( + "engine, sql", + [ + # `SELECT ... INTO new_table` is CTAS on Redshift and T-SQL just as it + # is on Postgres, so each dialect in _SELECT_INTO_CTAS_DIALECTS must + # classify the statement as mutating. + ("redshift", "SELECT * INTO new_table FROM existing_table"), + ("redshift", "SELECT col INTO new_table FROM existing_table"), + ("mssql", "SELECT * INTO new_table FROM existing_table"), + ("mssql", "SELECT col INTO new_table FROM existing_table"), + ], +) +def test_is_mutating_select_into_ctas_dialects(engine: str, sql: str) -> None: + """ + `SELECT ... INTO new_table` creates a table on the CTAS dialects beyond + Postgres (Redshift, T-SQL), so `is_mutating` must return True there. 
+ """ + assert SQLStatement(sql, engine).is_mutating() is True + + +@pytest.mark.parametrize( + "sql, expected", + [ + # PostgreSQL constructs that sqlglot parses as opaque exp.Command. + # Each can wrap a DML body or change effective server state. + ("PREPARE u AS UPDATE t SET x = 1", True), + ("PREPARE i AS INSERT INTO t VALUES (1)", True), + ("EXECUTE my_plan", True), + ("CALL my_writing_procedure()", True), + ("COPY t FROM '/tmp/data.csv'", True), + ("GRANT SELECT ON t TO public", True), + ("REVOKE SELECT ON t FROM public", True), + ("SET ROLE other_role", True), + ("REFRESH MATERIALIZED VIEW mv", True), + ("REINDEX TABLE t", True), + ("VACUUM t", True), + # SHOW commands disclose server configuration (version, hba_file, + # ssl, search_path, etc.). They are read-only but treated as gated + # so they are blocked by the same allow_dml=False gate that blocks + # the writers; this mirrors the existing read-side blocks in + # DISALLOWED_SQL_FUNCTIONS for pg_read_file, version(), etc. + ("SHOW search_path", True), + ("SHOW all", True), + ("SHOW server_version", True), + # RESET reverts a prior SET (e.g. RESET ROLE backs out SET ROLE). + ("RESET ROLE", True), + # DDL head-tokens that sqlglot falls back to exp.Command for when the + # body uses syntax it does not model. One representative per + # head-token branch (CREATE/ALTER/DROP); they all hit the same + # set-lookup so additional CREATE PUBLICATION/SUBSCRIPTION/etc. + # cases would not add coverage. + ( + "CREATE FUNCTION x() RETURNS int AS '/tmp/x.so', 'i' LANGUAGE C", + True, + ), + ("CREATE EXTENSION pg_trgm", True), # non-FUNCTION DDL via Command + ("ALTER SYSTEM SET wal_level = 'logical'", True), + ("DROP EXTENSION pg_trgm", True), + # LOAD dlopens a shared library on the PG host. Same RCE primitive + # as `CREATE FUNCTION ... LANGUAGE C` if the library path is + # attacker-controlled (e.g. via a prior COPY-to-program foothold).
+ ("LOAD '/tmp/x.so'", True), + # Case-insensitive: sqlglot preserves source case on Command.name, + # so the set lookup must normalise. Regression for the original + # bug where a lowercase head-token bypassed the gate. + ("create extension pg_trgm", True), + ("load '/tmp/x.so'", True), + # Pre-existing positive controls + ("DO $$ BEGIN UPDATE t SET x = 1; END $$", True), + ("EXPLAIN ANALYZE UPDATE t SET x = 1", True), + ], +) +def test_is_mutating_postgres_command_constructs(sql: str, expected: bool) -> None: + """ + Several PostgreSQL constructs are represented by sqlglot as opaque + `exp.Command` nodes (no structured AST). `is_mutating` recognises them + by command name so they cannot slip past the read-only gate. + """ + assert SQLStatement(sql, "postgresql").is_mutating() == expected + + +@pytest.mark.parametrize( + "sql, engine, functions, expected", + [ + # MySQL `@@` syntax parses as exp.SessionParameter, which is + # not a subclass of exp.Func. The walker must include it so the + # denylist entry for `version` still catches `SELECT @@version`. + ("SELECT @@version", "mysql", {"version"}, True), + ("SELECT @@global.version", "mysql", {"version"}, True), + ("SELECT @@hostname", "mysql", {"hostname"}, True), + ("SELECT @@datadir", "mysql", {"datadir"}, True), + # Negative control: a session parameter not in the denylist must + # not match. + ("SELECT @@autocommit", "mysql", {"version", "hostname"}, False), + # A plain SELECT does not introduce session-parameter names. + ("SELECT 1", "mysql", {"version"}, False), + # The pre-existing exp.Func walk still works for normal calls. + ("SELECT version()", "mysql", {"version"}, True), + # PostgreSQL large-object functions are exp.Anonymous calls. The + # walk includes them; the denylist entry catches them.
+ ("SELECT lo_export(12345, '/tmp/x')", "postgresql", {"lo_export"}, True), + ( + "SELECT lo_from_bytea(0, decode('00','hex'))", + "postgresql", + {"lo_from_bytea"}, + True, + ), + ("SELECT loread(12345, 1024)", "postgresql", {"loread"}, True), + ], +) +def test_check_functions_present_session_parameter( + sql: str, engine: str, functions: set[str], expected: bool +) -> None: + """ + `check_functions_present` must visit `exp.SessionParameter` so denylisted + names such as `version` also match MySQL `SELECT @@version`; plain function + calls (`version()`, PostgreSQL `lo_export`) stay covered as controls. + """ + assert SQLScript(sql, engine).check_functions_present(functions) == expected + + @pytest.mark.parametrize( "sql, expected", [ @@ -3215,6 +3413,298 @@ def test_check_tables_present(sql: str, engine: str, expected: bool) -> None: assert SQLScript(sql, engine).check_tables_present(tables) == expected +@pytest.mark.parametrize( + "engine, sql, denylist, expected", + [ + # Postgres: schema-qualified denylist entry matches schema-qualified + # reference. + ( + "postgresql", + "SELECT * FROM information_schema.tables", + {"information_schema.tables"}, + True, + ), + # ... and is case-insensitive. + ( + "postgresql", + "SELECT * FROM INFORMATION_SCHEMA.TABLES", + {"information_schema.tables"}, + True, + ), + # Schema-qualified denylist entry does NOT match a same-named table + # in another schema or an unqualified reference. A user table named + # `tables` remains queryable. + ( + "postgresql", + "SELECT * FROM public.tables", + {"information_schema.tables"}, + False, + ), + ( + "postgresql", + "SELECT * FROM tables", + {"information_schema.tables"}, + False, + ), + # Bare-name denylist entry still matches by table name only + # (existing behavior, schema-agnostic). + ( + "postgresql", + "SELECT * FROM pg_stat_activity", + {"pg_stat_activity"}, + True, + ), + ( + "postgresql", + "SELECT * FROM pg_catalog.pg_stat_activity", + {"pg_stat_activity"}, + True, + ), + # Mixed entries: one schema-qualified, one bare.
Match either. + ( + "postgresql", + "SELECT * FROM information_schema.columns", + {"information_schema.tables", "information_schema.columns"}, + True, + ), + ( + "postgresql", + "SELECT * FROM pg_roles", + {"information_schema.tables", "pg_roles"}, + True, + ), + # Negative control. + ( + "postgresql", + "SELECT * FROM my_table", + {"information_schema.tables", "pg_roles"}, + False, + ), + # MySQL: the shipped DISALLOWED_SQL_TABLES['mysql'] entries are all + # schema-qualified (`mysql.user`, `performance_schema.threads`, + # `performance_schema.processlist`). Without schema-aware matching + # the entries are dead config. These cases pin the fix. + ( + "mysql", + "SELECT user, host, authentication_string FROM mysql.user", + {"mysql.user"}, + True, + ), + ( + "mysql", + "SELECT * FROM performance_schema.threads", + {"performance_schema.threads"}, + True, + ), + ( + "mysql", + "SELECT * FROM performance_schema.processlist", + {"performance_schema.processlist"}, + True, + ), + # MySQL must NOT block a user-authored table that shares the leaf + # name with the system view. + ( + "mysql", + "SELECT * FROM mydb.user", + {"mysql.user"}, + False, + ), + # MSSQL: same shape, `sys.*` entries are schema-qualified. + ( + "mssql", + "SELECT name, password_hash FROM sys.sql_logins", + {"sys.sql_logins"}, + True, + ), + ( + "mssql", + "SELECT name, sid FROM sys.server_principals", + {"sys.server_principals"}, + True, + ), + ( + "mssql", + "SELECT * FROM sys.configurations", + {"sys.configurations"}, + True, + ), + # MSSQL must NOT block a user-authored table sharing the leaf name. + ( + "mssql", + "SELECT * FROM mydb.sql_logins", + {"sys.sql_logins"}, + False, + ), + ], +) +def test_check_tables_present_schema_qualified( + engine: str, sql: str, denylist: set[str], expected: bool +) -> None: + """ + `check_tables_present` must distinguish schema-qualified denylist + entries (e.g. `information_schema.tables`, `mysql.user`, + `sys.sql_logins`) from bare-name entries (e.g.
`pg_stat_activity`). + Schema-qualified entries only match schema-qualified references in + the SQL; bare entries match the table name regardless of schema. + + Covers Postgres, MySQL, and MSSQL dialects so the shipped + DISALLOWED_SQL_TABLES entries for each remain effective. + """ + assert SQLScript(sql, engine).check_tables_present(denylist) == expected + + +@pytest.mark.parametrize( + "engine, sql, denylist, expected", + [ + # A schema-qualified match is reported in its original denylist form, + # not collapsed to the bare leaf name and not the whole denylist. + ( + "postgresql", + "SELECT * FROM information_schema.tables", + {"information_schema.tables", "information_schema.columns", "pg_roles"}, + {"information_schema.tables"}, + ), + # Bare-name match is reported as-is. + ( + "postgresql", + "SELECT * FROM pg_catalog.pg_stat_activity", + {"pg_stat_activity", "pg_roles"}, + {"pg_stat_activity"}, + ), + # Multiple references across statements union their matches. + ( + "postgresql", + "SELECT * FROM information_schema.tables; SELECT * FROM pg_roles", + {"information_schema.tables", "pg_roles", "pg_settings"}, + {"information_schema.tables", "pg_roles"}, + ), + # No match returns an empty set. + ( + "postgresql", + "SELECT * FROM my_table", + {"information_schema.tables", "pg_roles"}, + set(), + ), + ], +) +def test_get_disallowed_tables( + engine: str, sql: str, denylist: set[str], expected: set[str] +) -> None: + """ + `get_disallowed_tables` returns exactly the denylist entries referenced, + in their original (possibly schema-qualified) form, so callers can report + precisely which tables were hit instead of echoing the whole denylist. + """ + assert SQLScript(sql, engine).get_disallowed_tables(denylist) == expected + + +@pytest.mark.parametrize( + "sql, default_schema, denylist, expected", + [ + # Unqualified reference resolves to the default schema, so it matches + # a schema-qualified denylist entry when the schemas line up (e.g.
a + # connection whose search_path is `information_schema`). + ( + "SELECT * FROM tables", + "information_schema", + {"information_schema.tables"}, + {"information_schema.tables"}, + ), + # ... case-insensitively. + ( + "SELECT * FROM tables", + "INFORMATION_SCHEMA", + {"information_schema.tables"}, + {"information_schema.tables"}, + ), + # The same unqualified name under a user schema must NOT match: a user + # table named `tables` stays queryable. + ( + "SELECT * FROM tables", + "public", + {"information_schema.tables"}, + set(), + ), + # An explicit schema on the reference wins over the default schema. + ( + "SELECT * FROM public.tables", + "information_schema", + {"information_schema.tables"}, + set(), + ), + # Without a default schema, behavior is unchanged: unqualified + # references never match schema-qualified entries. + ( + "SELECT * FROM tables", + None, + {"information_schema.tables"}, + set(), + ), + # Bare-name denylist entries are schema-agnostic and unaffected by the + # default schema. + ( + "SELECT * FROM pg_stat_activity", + "information_schema", + {"pg_stat_activity"}, + {"pg_stat_activity"}, + ), + # The default schema is forwarded to every statement in a script, so an + # unqualified reference in a later statement is resolved too. + ( + "SELECT * FROM my_table; SELECT * FROM tables", + "information_schema", + {"information_schema.tables"}, + {"information_schema.tables"}, + ), + ], +) +def test_get_disallowed_tables_default_schema( + sql: str, + default_schema: str | None, + denylist: set[str], + expected: set[str], +) -> None: + """ + `get_disallowed_tables` resolves an unqualified reference against the + supplied default schema, so a denylisted system view (e.g. + `information_schema.tables`) is still caught when reached without an + explicit schema under that search_path, without blocking a same-named + user table under a different schema.
+ """ + assert ( + SQLScript(sql, "postgresql").get_disallowed_tables(denylist, default_schema) + == expected + ) + + +@pytest.mark.parametrize( + "sql, denylist, expected", + [ + ("SELECT * FROM pg_stat_activity", {"pg_stat_activity"}, True), + ("SELECT * FROM my_table", {"pg_stat_activity"}, False), + ], +) +def test_statement_check_tables_present( + sql: str, denylist: set[str], expected: bool +) -> None: + """ + `SQLStatement.check_tables_present` is the per-statement entry point that + `SQLScript` no longer routes through (it calls `get_disallowed_tables` + directly), so exercise it on its own to keep the override covered. + """ + assert SQLStatement(sql, "postgresql").check_tables_present(denylist) == expected + + +def test_kustokql_statement_check_tables_present() -> None: + """ + `KustoKQLStatement.check_tables_present` is unsupported and always reports + False; exercise it directly so the override stays covered. + """ + statement = KustoKQLStatement("foo | take 100", "kustokql") + assert statement.check_tables_present({"foo"}) is False + + @pytest.mark.parametrize( "kql, expected", [ @@ -3395,6 +3885,17 @@ def test_backtick_invalid_sql_still_fails() -> None: SQLScript(sql, "base") +def test_base_sql_statement_is_destructive_raises_not_implemented() -> None: + """ + BaseSQLStatement.is_destructive is abstract; both concrete subclasses + (SQLStatement and KustoKQLStatement) override it, so calling the base + implementation directly must raise. Calling the stub here keeps it + under coverage. + """ + with pytest.raises(NotImplementedError): + BaseSQLStatement.is_destructive(object()) # type: ignore[arg-type] + + def test_backtick_fallback_logs_warning(caplog: pytest.LogCaptureFixture) -> None: """ Test that the MySQL dialect fallback emits a warning log.