Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cratedb_toolkit/cfr/systable.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def save(self) -> Path:

path_table_schema = path_schema / f"{ExportSettings.TABLE_FILENAME_PREFIX}{tablename}.sql"
path_table_data = path_data / f"{ExportSettings.TABLE_FILENAME_PREFIX}{tablename}.{self.data_format}"
tablename_out = f"{ExportSettings.TABLE_FILENAME_PREFIX}{tablename}"
tablename_out = self.adapter.quote_relation_name(f"{ExportSettings.TABLE_FILENAME_PREFIX}{tablename}")

# Write schema file.
with open(path_table_schema, "w") as fh_schema:
Expand Down
5 changes: 3 additions & 2 deletions cratedb_toolkit/testing/testcontainers/cratedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from cratedb_toolkit.testing.testcontainers.util import KeepaliveContainer, asbool
from cratedb_toolkit.util import DatabaseAdapter
from cratedb_toolkit.util.database import quote_table_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -189,7 +188,9 @@ def reset(self, tables: Optional[list] = None):
"""
if tables and self.database:
for reset_table in tables:
self.database.connection.exec_driver_sql(f"DROP TABLE IF EXISTS {quote_table_name(reset_table)};")
self.database.connection.exec_driver_sql(
f"DROP TABLE IF EXISTS {self.database.quote_relation_name(reset_table)};"
)

def get_connection_url(self, *args, **kwargs):
"""
Expand Down
60 changes: 37 additions & 23 deletions cratedb_toolkit/util/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,38 @@ def __init__(self, dburi: str, echo: bool = False):
self.engine = sa.create_engine(self.dburi, echo=echo)
self.connection = self.engine.connect()

def quote_relation_name(self, ident: str) -> str:
"""
Quote the given, possibly full-qualified, relation name if needed.

In: foo
Out: foo

In: Foo
Out: "Foo"

In: "Foo"
Out: "Foo"

In: foo.bar
Out: "foo"."bar"

In: "foo.bar"
Out: "foo.bar"
"""
if ident[0] == '"' and ident[len(ident) - 1] == '"':
return ident
if "." in ident:
parts = ident.split(".")
if len(parts) > 2:
raise ValueError(f"Invalid relation name {ident}")
return (
self.engine.dialect.identifier_preparer.quote_schema(parts[0])
+ "."
+ self.engine.dialect.identifier_preparer.quote(parts[1])
)
return self.engine.dialect.identifier_preparer.quote(ident=ident)

def run_sql(self, sql: t.Union[str, Path, io.IOBase], records: bool = False, ignore: str = None):
"""
Run SQL statement, and return results, optionally ignoring exceptions.
Expand Down Expand Up @@ -82,7 +114,7 @@ def count_records(self, name: str, errors: Literal["raise", "ignore"] = "raise")
"""
Return number of records in table.
"""
sql = f"SELECT COUNT(*) AS count FROM {quote_table_name(name)};" # noqa: S608
sql = f"SELECT COUNT(*) AS count FROM {self.quote_relation_name(name)};" # noqa: S608
try:
results = self.run_sql(sql=sql)
except ProgrammingError as ex:
Expand All @@ -96,7 +128,7 @@ def table_exists(self, name: str) -> bool:
"""
Check whether given table exists.
"""
sql = f"SELECT 1 FROM {quote_table_name(name)} LIMIT 1;" # noqa: S608
sql = f"SELECT 1 FROM {self.quote_relation_name(name)} LIMIT 1;" # noqa: S608
try:
self.run_sql(sql=sql)
return True
Expand All @@ -107,15 +139,15 @@ def refresh_table(self, name: str):
"""
Run a `REFRESH TABLE ...` command.
"""
sql = f"REFRESH TABLE {quote_table_name(name)};" # noqa: S608
sql = f"REFRESH TABLE {self.quote_relation_name(name)};" # noqa: S608
self.run_sql(sql=sql)
return True

def prune_table(self, name: str, errors: Literal["raise", "ignore"] = "raise"):
"""
Run a `DELETE FROM ...` command.
"""
sql = f"DELETE FROM {quote_table_name(name)};" # noqa: S608
sql = f"DELETE FROM {self.quote_relation_name(name)};" # noqa: S608
try:
self.run_sql(sql=sql)
except ProgrammingError as ex:
Expand All @@ -129,7 +161,7 @@ def drop_table(self, name: str):
"""
Run a `DROP TABLE ...` command.
"""
sql = f"DROP TABLE IF EXISTS {quote_table_name(name)};" # noqa: S608
sql = f"DROP TABLE IF EXISTS {self.quote_relation_name(name)};" # noqa: S608
self.run_sql(sql=sql)
return True

Expand Down Expand Up @@ -332,21 +364,3 @@ def decode_database_table(url: str) -> t.Tuple[str, str]:
if url_.scheme == "crate" and not database:
database = url_.query_params.get("schema")
return database, table


def quote_table_name(name: str) -> str:
"""
Quote table name if not happened already.

In: foo
Out: "foo"

In: "foo"
Out: "foo"

In: foo.bar
Out: foo.bar
"""
if '"' not in name and "." not in name:
name = f'"{name}"'
return name
16 changes: 16 additions & 0 deletions tests/cfr/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ def test_cfr_cli_export_failure(cratedb, tmp_path, caplog):
assert result.output == ""


def test_cfr_cli_export_ensure_table_name_is_quoted(cratedb, tmp_path, caplog):
runner = CliRunner(env={"CRATEDB_SQLALCHEMY_URL": cratedb.database.dburi, "CFR_TARGET": str(tmp_path)})
result = runner.invoke(
cli,
args="--debug sys-export",
catch_exceptions=False,
)
assert result.exit_code == 0

path = Path(json.loads(result.output)["path"])
sys_cluster_table_schema = path / "schema" / "sys-cluster.sql"
with open(sys_cluster_table_schema, "r") as f:
content = f.read()
assert '"sys-cluster"' in content, "Table name missing or not quoted"


def test_cfr_cli_import_success(cratedb, tmp_path, caplog):
"""
Verify `ctk cfr sys-import` works.
Expand Down
23 changes: 23 additions & 0 deletions tests/util/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from cratedb_toolkit.util import DatabaseAdapter


def test_quote_relation_name():
database = DatabaseAdapter(dburi="crate://localhost")
assert database.quote_relation_name("my_table") == "my_table"
assert database.quote_relation_name("my-table") == '"my-table"'
assert database.quote_relation_name("MyTable") == '"MyTable"'
assert database.quote_relation_name('"MyTable"') == '"MyTable"'
assert database.quote_relation_name("my_schema.my_table") == "my_schema.my_table"
assert database.quote_relation_name("my-schema.my_table") == '"my-schema".my_table'
assert database.quote_relation_name('"wrong-quoted-fqn.my_table"') == '"wrong-quoted-fqn.my_table"'
assert database.quote_relation_name('"my_schema"."my_table"') == '"my_schema"."my_table"'
# reserved keyword must be quoted
assert database.quote_relation_name("table") == '"table"'


def test_quote_relation_name_with_invalid_fqn():
database = DatabaseAdapter(dburi="crate://localhost")
with pytest.raises(ValueError):
database.quote_relation_name("my-db.my-schema.my-table")