diff --git a/db2sql/application/use_cases/migrate_database.py b/db2sql/application/use_cases/migrate_database.py index 3f25590..4531874 100644 --- a/db2sql/application/use_cases/migrate_database.py +++ b/db2sql/application/use_cases/migrate_database.py @@ -94,7 +94,10 @@ def _load_data(self, database: Database) -> None: rows = self._reader.iter_query_rows(table.source_query, limit=limit) else: rows = self._reader.iter_rows(schema_name, table, limit=limit) - self._writer.bulk_load(schema_name, table, rows) + mapped = options.mapping_schemas.get( + schema_name, options.mapping_schemas.get("*", schema_name) + ) + self._writer.bulk_load(mapped, table, rows) @staticmethod def _resolve_limit( diff --git a/db2sql/domain/model/__init__.py b/db2sql/domain/model/__init__.py index 546fb4f..3a16b19 100644 --- a/db2sql/domain/model/__init__.py +++ b/db2sql/domain/model/__init__.py @@ -2,8 +2,8 @@ from .column import Column from .database import Database -from .foreign_key import ForeignKey +from .foreign_key import ForeignKey, ForeignKeyConstraint from .schema import Schema from .table import Table -__all__ = ["Column", "Database", "ForeignKey", "Schema", "Table"] +__all__ = ["Column", "Database", "ForeignKey", "ForeignKeyConstraint", "Schema", "Table"] diff --git a/db2sql/domain/model/foreign_key.py b/db2sql/domain/model/foreign_key.py index f66f4a0..f6cdf7c 100644 --- a/db2sql/domain/model/foreign_key.py +++ b/db2sql/domain/model/foreign_key.py @@ -1,8 +1,9 @@ -"""Foreign key value object.""" +"""Foreign key value objects.""" from __future__ import annotations from dataclasses import dataclass +from typing import Tuple @dataclass(frozen=True) @@ -12,3 +13,14 @@ class ForeignKey: schema: str table: str column: str + + +@dataclass(frozen=True) +class ForeignKeyConstraint: + """A named FK constraint grouping one or more columns. Immutable value object.""" + + name: str + ref_schema: str + ref_table: str + columns: Tuple[str, ...] + ref_columns: Tuple[str, ...] diff --git a/db2sql/domain/model/table.py b/db2sql/domain/model/table.py index f91a728..f711bcf 100644 --- a/db2sql/domain/model/table.py +++ b/db2sql/domain/model/table.py @@ -8,6 +8,7 @@ from db2sql.domain.errors import DuplicatedColumnError from .column import Column +from .foreign_key import ForeignKeyConstraint @dataclass @@ -17,6 +18,7 @@ class Table: name: str columns: Dict[str, Column] = field(default_factory=dict) indexes: Dict[str, List[str]] = field(default_factory=dict) + foreign_key_constraints: List[ForeignKeyConstraint] = field(default_factory=list) source_query: Optional[str] = None def add_column(self, column: Column) -> None: diff --git a/db2sql/domain/policy/identifier.py b/db2sql/domain/policy/identifier.py index 656926e..bda2ed4 100644 --- a/db2sql/domain/policy/identifier.py +++ b/db2sql/domain/policy/identifier.py @@ -8,7 +8,14 @@ def to_snake_case(name: str) -> str: - """Convert CamelCase to snake_case (idempotent on already-snake_case input).""" + """Convert CamelCase to snake_case (idempotent on already-snake_case input). + + Pure-uppercase names (e.g. ``CLIENTS``) are lowered without inserting + underscores, so ``CLIENTS`` becomes ``clients`` rather than ``c_l_i_e_n_t_s``. + Mixed-case names like ``ClientName`` become ``client_name``. + """ + if name.isupper() or name.islower(): + return name.lower() return _CAMEL_RE.sub("_", name).lower() diff --git a/db2sql/infrastructure/config/loader.py b/db2sql/infrastructure/config/loader.py index 86b8579..35ac7a0 100644 --- a/db2sql/infrastructure/config/loader.py +++ b/db2sql/infrastructure/config/loader.py @@ -139,6 +139,9 @@ def merge_cli_overrides(config: AppConfig, options: Mapping[str, Any]) -> AppCon dump["default_data_format"] = value elif key == "output_file_name": data["output_file"] = value + elif key == "target_schema": + schemas = dump.setdefault("mapping_schemas", {}) + schemas["*"] = value elif key == "split_size": data["split_size"] = value diff --git a/db2sql/infrastructure/emit/mssql/emitter.py b/db2sql/infrastructure/emit/mssql/emitter.py index c24c0f7..b6c0d22 100644 --- a/db2sql/infrastructure/emit/mssql/emitter.py +++ b/db2sql/infrastructure/emit/mssql/emitter.py @@ -102,7 +102,7 @@ def quote_identifier(self, name: str) -> str: return "[{}]".format(normalized.replace("]", "]]")) def schema_name(self, schema: Schema) -> str: - mapped = self._schema_mapping.get(schema.name, schema.name) + mapped = self._schema_mapping.get(schema.name, self._schema_mapping.get("*", schema.name)) return self.quote_identifier(mapped) def table_name(self, schema: Schema, table: Table) -> str: @@ -148,7 +148,9 @@ def emit_epilogue(self, sink: OutputSink) -> None: def emit_schemas(self, database: Database, sink: OutputSink) -> None: emitted = set() for schema in database.schemas.values(): - target = self._schema_mapping.get(schema.name, schema.name) + target = self._schema_mapping.get( + schema.name, self._schema_mapping.get("*", schema.name) + ) if target in emitted: continue emitted.add(target) @@ -213,22 +215,20 @@ def emit_foreign_keys(self, database: Database, sink: OutputSink) -> None: for schema in database.schemas.values(): for table in schema.tables.values(): qualified = self.table_name(schema, table) - for column in table.columns.values(): - fk = column.foreign_key - if not fk: - continue - ref_schema = database.schemas.get(fk.schema) + for fkc in table.foreign_key_constraints: + ref_schema = database.schemas.get(fkc.ref_schema) if ref_schema is None: continue - ref_table = ref_schema.get_table(fk.table) + ref_table = ref_schema.get_table(fkc.ref_table) if ref_table is None: continue ref_qualified = self.table_name(ref_schema, ref_table) + cols = ", ".join(self.quote_identifier(c) for c in fkc.columns) + ref_cols = ", ".join(self.quote_identifier(c) for c in fkc.ref_columns) sink.write( f"ALTER TABLE {qualified} " - f"ADD FOREIGN KEY ({self.quote_identifier(column.name)}) " - f"REFERENCES {ref_qualified} " - f"({self.quote_identifier(fk.column)});\n" + f"ADD FOREIGN KEY ({cols}) " + f"REFERENCES {ref_qualified} ({ref_cols});\n" ) sink.boundary() sink.write("\n") diff --git a/db2sql/infrastructure/emit/postgres/emitter.py b/db2sql/infrastructure/emit/postgres/emitter.py index b7d04f5..a5c35d2 100644 --- a/db2sql/infrastructure/emit/postgres/emitter.py +++ b/db2sql/infrastructure/emit/postgres/emitter.py @@ -2,12 +2,27 @@ from __future__ import annotations +import re from typing import Any, Dict, Iterable, Mapping, Optional from db2sql.application.ports import OutputSink from db2sql.domain.model import Column, Database, Schema, Table from db2sql.domain.policy import drop_order, normalize_identifier, topological_order +_DEFAULT_FUNCTION_MAP: Dict[str, str] = { + "getdate()": "now()", + "getutcdate()": "(now() AT TIME ZONE 'UTC')", + "sysdatetime()": "now()", + "sysutcdatetime()": "(now() AT TIME ZONE 'UTC')", + "newid()": "gen_random_uuid()", + "newsequentialid()": "gen_random_uuid()", +} + +_DEFAULT_FUNCTION_RE = re.compile( + "|".join(re.escape(k) for k in _DEFAULT_FUNCTION_MAP), + re.IGNORECASE, +) + class PostgresSqlEmitter: """Produce PostgreSQL DDL+DML for a collected :class:`Database`.""" @@ -93,7 +108,7 @@ def quote_identifier(self, name: str) -> str: return '"{}"'.format(normalized.replace('"', '""')) def schema_name(self, schema: Schema) -> str: - mapped = self._schema_mapping.get(schema.name, schema.name) + mapped = self._schema_mapping.get(schema.name, self._schema_mapping.get("*", schema.name)) return self.quote_identifier(mapped) def table_name(self, schema: Schema, table: Table) -> str: @@ -106,7 +121,17 @@ def _map_type(self, column: Column) -> str: if column.char_length and column.char_length > 0: return f"{target}({column.char_length})" if target == "numeric" and column.precision: - scale = column.scale or 0 + scale = column.scale if column.scale is not None else 0 + if column.scale is not None and scale == 0: + # numeric(p,0) is an integer stored in a decimal type (common + # MSSQL pattern). Promote to a native integer type so that + # serial/bigserial works for identity columns and FK types + # stay consistent across referencing tables. + if column.precision > 9: + return "bigint" + if column.precision > 4: + return "integer" + return "smallint" return f"numeric({column.precision},{scale})" return target @@ -118,9 +143,20 @@ def column_definition(self, column: Column) -> str: if not column.nullable and not column.identity: parts.append("NOT NULL") if column.default is not None and not column.identity: - parts.append(f"DEFAULT {column.default}") + parts.append(f"DEFAULT {self._translate_default(column.default)}") return " ".join(parts) + @staticmethod + def _translate_default(default: str) -> str: + """Translate source-dialect default expressions to PostgreSQL.""" + # Strip wrapping parentheses added by MSSQL (e.g. "(getdate())") + stripped = default.strip() + while stripped.startswith("(") and stripped.endswith(")"): + stripped = stripped[1:-1].strip() + return _DEFAULT_FUNCTION_RE.sub( + lambda m: _DEFAULT_FUNCTION_MAP[m.group(0).lower()], stripped + ) + # ---- emit ------------------------------------------------------------- def emit_prologue(self, sink: OutputSink) -> None: @@ -134,7 +170,9 @@ def emit_epilogue(self, sink: OutputSink) -> None: def emit_schemas(self, database: Database, sink: OutputSink) -> None: emitted = set() for schema in database.schemas.values(): - target = self._schema_mapping.get(schema.name, schema.name) + target = self._schema_mapping.get( + schema.name, self._schema_mapping.get("*", schema.name) + ) if target in emitted: continue emitted.add(target) @@ -190,22 +228,20 @@ def emit_foreign_keys(self, database: Database, sink: OutputSink) -> None: for schema in database.schemas.values(): for table in schema.tables.values(): qualified = self.table_name(schema, table) - for column in table.columns.values(): - fk = column.foreign_key - if not fk: - continue - ref_schema = database.schemas.get(fk.schema) + for fkc in table.foreign_key_constraints: + ref_schema = database.schemas.get(fkc.ref_schema) if ref_schema is None: continue - ref_table = ref_schema.get_table(fk.table) + ref_table = ref_schema.get_table(fkc.ref_table) if ref_table is None: continue ref_qualified = self.table_name(ref_schema, ref_table) + cols = ", ".join(self.quote_identifier(c) for c in fkc.columns) + ref_cols = ", ".join(self.quote_identifier(c) for c in fkc.ref_columns) sink.write( f"ALTER TABLE {qualified} " - f"ADD FOREIGN KEY ({self.quote_identifier(column.name)}) " - f"REFERENCES {ref_qualified} " - f"({self.quote_identifier(fk.column)});\n" + f"ADD FOREIGN KEY ({cols}) " + f"REFERENCES {ref_qualified} ({ref_cols});\n" ) sink.boundary() sink.write("\n") diff --git a/db2sql/infrastructure/persistence/mssql/reader.py b/db2sql/infrastructure/persistence/mssql/reader.py index b43056a..84c5297 100644 --- a/db2sql/infrastructure/persistence/mssql/reader.py +++ b/db2sql/infrastructure/persistence/mssql/reader.py @@ -2,13 +2,21 @@ from __future__ import annotations -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple +from urllib.parse import quote_plus from sqlalchemy import create_engine, engine, text from sqlalchemy.orm.session import Session, sessionmaker from db2sql.application.ports import Logger -from db2sql.domain.model import Column, Database, ForeignKey, Schema, Table +from db2sql.domain.model import ( + Column, + Database, + ForeignKey, + ForeignKeyConstraint, + Schema, + Table, +) from db2sql.infrastructure.config import AppConfig from db2sql.infrastructure.persistence import query_introspection from db2sql.infrastructure.persistence.errors import SourceReaderError @@ -35,8 +43,8 @@ def _connection_string(self) -> str: server = self._config.server port = f":{server.port}" if server.port else "" return "mssql+pymssql://{}:{}@{}{}/{}".format( - server.username or "", - server.password or "", + quote_plus(server.username or ""), + quote_plus(server.password or ""), server.hostname or "", port, server.dbname or "", @@ -254,21 +262,46 @@ def _read_foreign_keys(self, database: Database) -> None: AND KCU2.CONSTRAINT_NAME = RC.UNIQUE_CONSTRAINT_NAME WHERE KCU1.ORDINAL_POSITION = KCU2.ORDINAL_POSITION AND KCU1.TABLE_SCHEMA not in ('sys', 'guest', 'information_schema') -ORDER BY CONSTRAINT_SCHEMA, CONSTRAINT_NAME +ORDER BY CONSTRAINT_SCHEMA, CONSTRAINT_NAME, KCU1.ORDINAL_POSITION """ ) ) + # Group FK rows by (schema, table, constraint_name) + groups: Dict[Tuple[str, str, str], List[Any]] = {} for row in r: - table = database.get_table(row.TABLE_SCHEMA, row.TABLE_NAME) - if table: + key = (row.TABLE_SCHEMA, row.TABLE_NAME, row.CONSTRAINT_NAME) + groups.setdefault(key, []).append(row) + + for (schema_name, table_name, constraint_name), rows in groups.items(): + table = database.get_table(schema_name, table_name) + if table is None: + continue + cols: List[str] = [] + ref_cols: List[str] = [] + valid = True + for row in rows: column = table.get_column(row.COLUMN_NAME) - if column: - column.foreign_key = ForeignKey( - row.UNIQUE_TABLE_SCHEMA, - row.UNIQUE_TABLE_NAME, - row.UNIQUE_COLUMN_NAME, + if column is None: + valid = False + break + column.foreign_key = ForeignKey( + row.UNIQUE_TABLE_SCHEMA, + row.UNIQUE_TABLE_NAME, + row.UNIQUE_COLUMN_NAME, + ) + cols.append(row.COLUMN_NAME) + ref_cols.append(row.UNIQUE_COLUMN_NAME) + if valid and cols: + table.foreign_key_constraints.append( + ForeignKeyConstraint( + name=constraint_name, + ref_schema=rows[0].UNIQUE_TABLE_SCHEMA, + ref_table=rows[0].UNIQUE_TABLE_NAME, + columns=tuple(cols), + ref_columns=tuple(ref_cols), ) + ) def _read_indexes(self, database: Database) -> None: r: engine.Result[Any] = self._ensure_session().execute( diff --git a/db2sql/infrastructure/persistence/mysql/reader.py b/db2sql/infrastructure/persistence/mysql/reader.py index 015ffc2..3954ecd 100644 --- a/db2sql/infrastructure/persistence/mysql/reader.py +++ b/db2sql/infrastructure/persistence/mysql/reader.py @@ -2,13 +2,21 @@ from __future__ import annotations -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple +from urllib.parse import quote_plus from sqlalchemy import create_engine, engine, text from sqlalchemy.orm.session import Session, sessionmaker from db2sql.application.ports import Logger -from db2sql.domain.model import Column, Database, ForeignKey, Schema, Table +from db2sql.domain.model import ( + Column, + Database, + ForeignKey, + ForeignKeyConstraint, + Schema, + Table, +) from db2sql.infrastructure.config import AppConfig from db2sql.infrastructure.persistence import query_introspection from db2sql.infrastructure.persistence.errors import SourceReaderError @@ -35,8 +43,8 @@ def _connection_string(self) -> str: server = self._config.server port = f":{server.port}" if server.port else "" return "mysql+pymysql://{}:{}@{}{}/{}".format( - server.username or "", - server.password or "", + quote_plus(server.username or ""), + quote_plus(server.password or ""), server.hostname or "", port, server.dbname or "", @@ -129,24 +137,49 @@ def _read_constraints(self, database: Database) -> None: def _read_foreign_keys(self, database: Database) -> None: rows = self._ensure_session().execute( text( - "SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME " + "SELECT CONSTRAINT_NAME, TABLE_NAME, COLUMN_NAME, " + "REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME " "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE " - "WHERE TABLE_SCHEMA = :schema AND REFERENCED_TABLE_NAME IS NOT NULL" + "WHERE TABLE_SCHEMA = :schema AND REFERENCED_TABLE_NAME IS NOT NULL " + "ORDER BY CONSTRAINT_NAME, ORDINAL_POSITION" ), {"schema": self._database_name}, ) + + groups: Dict[Tuple[str, str], List[Any]] = {} for row in rows: - table = database.get_table(self._database_name, row.TABLE_NAME) + key = (row.TABLE_NAME, row.CONSTRAINT_NAME) + groups.setdefault(key, []).append(row) + + for (table_name, constraint_name), fk_rows in groups.items(): + table = database.get_table(self._database_name, table_name) if table is None: continue - column = table.get_column(row.COLUMN_NAME) - if column is None: - continue - column.foreign_key = ForeignKey( - self._database_name, - row.REFERENCED_TABLE_NAME, - row.REFERENCED_COLUMN_NAME, - ) + cols: List[str] = [] + ref_cols: List[str] = [] + valid = True + for row in fk_rows: + column = table.get_column(row.COLUMN_NAME) + if column is None: + valid = False + break + column.foreign_key = ForeignKey( + self._database_name, + row.REFERENCED_TABLE_NAME, + row.REFERENCED_COLUMN_NAME, + ) + cols.append(row.COLUMN_NAME) + ref_cols.append(row.REFERENCED_COLUMN_NAME) + if valid and cols: + table.foreign_key_constraints.append( + ForeignKeyConstraint( + name=constraint_name, + ref_schema=self._database_name, + ref_table=fk_rows[0].REFERENCED_TABLE_NAME, + columns=tuple(cols), + ref_columns=tuple(ref_cols), + ) + ) def _read_indexes(self, database: Database) -> None: rows = self._ensure_session().execute( diff --git a/db2sql/infrastructure/persistence/oracle/reader.py b/db2sql/infrastructure/persistence/oracle/reader.py index 0c45736..8354b0a 100644 --- a/db2sql/infrastructure/persistence/oracle/reader.py +++ b/db2sql/infrastructure/persistence/oracle/reader.py @@ -3,12 +3,20 @@ from __future__ import annotations from typing import Any, Dict, Iterator, List, Optional, Tuple +from urllib.parse import quote_plus from sqlalchemy import create_engine, engine, text from sqlalchemy.orm.session import Session, sessionmaker from db2sql.application.ports import Logger -from db2sql.domain.model import Column, Database, ForeignKey, Schema, Table +from db2sql.domain.model import ( + Column, + Database, + ForeignKey, + ForeignKeyConstraint, + Schema, + Table, +) from db2sql.infrastructure.config import AppConfig from db2sql.infrastructure.persistence import query_introspection from db2sql.infrastructure.persistence.errors import SourceReaderError @@ -102,7 +110,9 @@ def _connection_string(self) -> str: options = server.options or {} driver = options.get("driver", "oracledb") port = f":{server.port}" if server.port else "" - userinfo = "{}:{}".format(server.username or "", server.password or "") + userinfo = "{}:{}".format( + quote_plus(server.username or ""), quote_plus(server.password or "") + ) host = server.hostname or "" service_name = options.get("service_name") sid = options.get("sid") @@ -272,6 +282,7 @@ def _read_foreign_keys(self, database: Database) -> None: rows = self._ensure_session().execute( text( "SELECT cc.OWNER, cc.TABLE_NAME, cc.COLUMN_NAME, cc.POSITION, " + " cc.CONSTRAINT_NAME, " " rc.OWNER AS REF_OWNER, rc.TABLE_NAME AS REF_TABLE, " " rc.COLUMN_NAME AS REF_COLUMN " "FROM ALL_CONSTRAINTS c " @@ -288,14 +299,37 @@ def _read_foreign_keys(self, database: Database) -> None: ), params, ) + + groups: Dict[Tuple[str, str, str], List[Any]] = {} for row in rows: - table = database.get_table(row.owner, row.table_name) + key = (row.owner, row.table_name, row.constraint_name) + groups.setdefault(key, []).append(row) + + for (schema_name, table_name, constraint_name), fk_rows in groups.items(): + table = database.get_table(schema_name, table_name) if table is None: continue - column = table.get_column(row.column_name) - if column is None: - continue - column.foreign_key = ForeignKey(row.ref_owner, row.ref_table, row.ref_column) + cols: List[str] = [] + ref_cols: List[str] = [] + valid = True + for row in fk_rows: + column = table.get_column(row.column_name) + if column is None: + valid = False + break + column.foreign_key = ForeignKey(row.ref_owner, row.ref_table, row.ref_column) + cols.append(row.column_name) + ref_cols.append(row.ref_column) + if valid and cols: + table.foreign_key_constraints.append( + ForeignKeyConstraint( + name=constraint_name, + ref_schema=fk_rows[0].ref_owner, + ref_table=fk_rows[0].ref_table, + columns=tuple(cols), + ref_columns=tuple(ref_cols), + ) + ) def _read_indexes(self, database: Database) -> None: owner = self._schema_filter diff --git a/db2sql/infrastructure/persistence/postgres/reader.py b/db2sql/infrastructure/persistence/postgres/reader.py index 3672c23..be33fc8 100644 --- a/db2sql/infrastructure/persistence/postgres/reader.py +++ b/db2sql/infrastructure/persistence/postgres/reader.py @@ -2,13 +2,21 @@ from __future__ import annotations -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple +from urllib.parse import quote_plus from sqlalchemy import create_engine, engine, text from sqlalchemy.orm.session import Session, sessionmaker from db2sql.application.ports import Logger -from db2sql.domain.model import Column, Database, ForeignKey, Schema, Table +from db2sql.domain.model import ( + Column, + Database, + ForeignKey, + ForeignKeyConstraint, + Schema, + Table, +) from db2sql.infrastructure.config import AppConfig from db2sql.infrastructure.persistence import query_introspection from db2sql.infrastructure.persistence.errors import SourceReaderError @@ -37,8 +45,8 @@ def _connection_string(self) -> str: server = self._config.server port = f":{server.port}" if server.port else "" return "postgresql+psycopg2://{}:{}@{}{}/{}".format( - server.username or "", - server.password or "", + quote_plus(server.username or ""), + quote_plus(server.password or ""), server.hostname or "", port, server.dbname or "", @@ -123,7 +131,8 @@ def _read_constraints(self, database: Database) -> None: def _read_foreign_keys(self, database: Database) -> None: rows = self._ensure_session().execute( text( - "SELECT k1.TABLE_SCHEMA, k1.TABLE_NAME, k1.COLUMN_NAME, " + "SELECT rc.CONSTRAINT_NAME, " + " k1.TABLE_SCHEMA, k1.TABLE_NAME, k1.COLUMN_NAME, " " k2.TABLE_SCHEMA AS REF_SCHEMA, k2.TABLE_NAME AS REF_TABLE, " " k2.COLUMN_NAME AS REF_COLUMN " "FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc " @@ -134,17 +143,41 @@ def _read_foreign_keys(self, database: Database) -> None: " ON k2.CONSTRAINT_NAME = rc.UNIQUE_CONSTRAINT_NAME " " AND k2.CONSTRAINT_SCHEMA = rc.UNIQUE_CONSTRAINT_SCHEMA " " AND k1.ORDINAL_POSITION = k2.ORDINAL_POSITION " - f"WHERE k1.TABLE_SCHEMA NOT IN {_SYSTEM_SCHEMAS}" + f"WHERE k1.TABLE_SCHEMA NOT IN {_SYSTEM_SCHEMAS} " + "ORDER BY rc.CONSTRAINT_NAME, k1.ORDINAL_POSITION" ) ) + + groups: Dict[Tuple[str, str, str], List[Any]] = {} for row in rows: - table = database.get_table(row.table_schema, row.table_name) + key = (row.table_schema, row.table_name, row.constraint_name) + groups.setdefault(key, []).append(row) + + for (schema_name, table_name, constraint_name), fk_rows in groups.items(): + table = database.get_table(schema_name, table_name) if table is None: continue - column = table.get_column(row.column_name) - if column is None: - continue - column.foreign_key = ForeignKey(row.ref_schema, row.ref_table, row.ref_column) + cols: List[str] = [] + ref_cols: List[str] = [] + valid = True + for row in fk_rows: + column = table.get_column(row.column_name) + if column is None: + valid = False + break + column.foreign_key = ForeignKey(row.ref_schema, row.ref_table, row.ref_column) + cols.append(row.column_name) + ref_cols.append(row.ref_column) + if valid and cols: + table.foreign_key_constraints.append( + ForeignKeyConstraint( + name=constraint_name, + ref_schema=fk_rows[0].ref_schema, + ref_table=fk_rows[0].ref_table, + columns=tuple(cols), + ref_columns=tuple(ref_cols), + ) + ) def _read_indexes(self, database: Database) -> None: rows = self._ensure_session().execute( diff --git a/db2sql/infrastructure/persistence/sqlite/reader.py b/db2sql/infrastructure/persistence/sqlite/reader.py index 3a19fe8..95adeab 100644 --- a/db2sql/infrastructure/persistence/sqlite/reader.py +++ b/db2sql/infrastructure/persistence/sqlite/reader.py @@ -2,13 +2,20 @@ from __future__ import annotations -from typing import Any, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from sqlalchemy import create_engine, engine, text from sqlalchemy.orm.session import Session, sessionmaker from db2sql.application.ports import Logger -from db2sql.domain.model import Column, Database, ForeignKey, Schema, Table +from db2sql.domain.model import ( + Column, + Database, + ForeignKey, + ForeignKeyConstraint, + Schema, + Table, +) from db2sql.infrastructure.config import AppConfig from db2sql.infrastructure.persistence import query_introspection from db2sql.infrastructure.persistence.errors import SourceReaderError @@ -117,12 +124,36 @@ def _read_foreign_keys(self, database: Database, table_name: str) -> None: table = database.get_table(self._schema, table_name) if table is None: return + + # Group rows by constraint id (first element of each PRAGMA row) + groups: Dict[int, List[Tuple[str, str, str]]] = {} for row in rows: - _, _, ref_table, src_col, ref_col, *_ = row - column = table.get_column(src_col) - if column is None: - continue - column.foreign_key = ForeignKey(self._schema, ref_table, ref_col) + fk_id, _, ref_table, src_col, ref_col, *_ = row + groups.setdefault(fk_id, []).append((ref_table, src_col, ref_col)) + + for fk_id, fk_rows in groups.items(): + cols: List[str] = [] + ref_cols: List[str] = [] + valid = True + ref_table = fk_rows[0][0] + for _, src_col, ref_col in fk_rows: + column = table.get_column(src_col) + if column is None: + valid = False + break + column.foreign_key = ForeignKey(self._schema, ref_table, ref_col) + cols.append(src_col) + ref_cols.append(ref_col) + if valid and cols: + table.foreign_key_constraints.append( + ForeignKeyConstraint( + name=f"{table_name}_fk_{fk_id}", + ref_schema=self._schema, + ref_table=ref_table, + columns=tuple(cols), + ref_columns=tuple(ref_cols), + ) + ) def iter_rows(self, schema: str, table: Table, limit: int = -1) -> Iterator[Tuple[Any, ...]]: session = self._ensure_session() diff --git a/db2sql/infrastructure/writer/mssql/writer.py b/db2sql/infrastructure/writer/mssql/writer.py index 3def634..5538fbd 100644 --- a/db2sql/infrastructure/writer/mssql/writer.py +++ b/db2sql/infrastructure/writer/mssql/writer.py @@ -12,12 +12,14 @@ from types import TracebackType from typing import Any, Iterator, List, Optional, Tuple +from urllib.parse import quote_plus from sqlalchemy import create_engine, text from sqlalchemy.engine import Connection, Engine from db2sql.application.ports import Logger from db2sql.domain.model import Table +from db2sql.domain.policy import normalize_identifier from db2sql.infrastructure.config import AppConfig from db2sql.infrastructure.writer.errors import ( TargetWriterConnectionError, @@ -37,6 +39,7 @@ class MssqlTargetWriter: def __init__(self, config: AppConfig, logger: Logger) -> None: self._config = config self._logger = logger + self._preserve_case: bool = config.dump.preserve_case self._engine: Optional[Engine] = None self._connection: Optional[Connection] = None self._batch_size = max(1, config.migrate.batch_size) @@ -131,8 +134,8 @@ def _connection_string(self) -> str: server = self._config.target_server port = f":{server.port}" if server.port else "" return "mssql+pymssql://{}:{}@{}{}/{}".format( - server.username or "", - server.password or "", + quote_plus(server.username or ""), + quote_plus(server.password or ""), server.hostname or "", port, server.dbname or "", @@ -149,6 +152,6 @@ def _connection_string_redacted(self) -> str: server.dbname or "", ) - @staticmethod - def _quote_ident(name: str) -> str: - return "[{}]".format(name.replace("]", "]]")) + def _quote_ident(self, name: str) -> str: + normalized = normalize_identifier(name, self._preserve_case) + return "[{}]".format(normalized.replace("]", "]]")) diff --git a/db2sql/infrastructure/writer/postgres/writer.py b/db2sql/infrastructure/writer/postgres/writer.py index e1364a7..6af8c74 100644 --- a/db2sql/infrastructure/writer/postgres/writer.py +++ b/db2sql/infrastructure/writer/postgres/writer.py @@ -13,12 +13,14 @@ import io from types import TracebackType from typing import Any, Iterator, Optional, Tuple +from urllib.parse import quote_plus from sqlalchemy import create_engine, text from sqlalchemy.engine import Connection, Engine from db2sql.application.ports import Logger from db2sql.domain.model import Table +from db2sql.domain.policy import normalize_identifier from db2sql.infrastructure.config import AppConfig from db2sql.infrastructure.writer.errors import ( TargetWriterConnectionError, @@ -38,6 +40,7 @@ class PostgresTargetWriter: def __init__(self, config: AppConfig, logger: Logger) -> None: self._config = config self._logger = logger + self._preserve_case: bool = config.dump.preserve_case self._engine: Optional[Engine] = None self._connection: Optional[Connection] = None @@ -129,8 +132,8 @@ def _connection_string(self) -> str: server = self._config.target_server port = f":{server.port}" if server.port else "" return "postgresql+psycopg2://{}:{}@{}{}/{}".format( - server.username or "", - server.password or "", + quote_plus(server.username or ""), + quote_plus(server.password or ""), server.hostname or "", port, server.dbname or "", @@ -147,9 +150,9 @@ def _connection_string_redacted(self) -> str: server.dbname or "", ) - @staticmethod - def _quote_ident(name: str) -> str: - return '"{}"'.format(name.replace('"', '""')) + def _quote_ident(self, name: str) -> str: + normalized = normalize_identifier(name, self._preserve_case) + return '"{}"'.format(normalized.replace('"', '""')) # Mirrors PostgresSqlEmitter._format_copy_value so the bytes hitting the # server through COPY FROM STDIN are byte-identical to the file dump. diff --git a/db2sql/interface/cli/parser.py b/db2sql/interface/cli/parser.py index 4f6eabf..126097d 100644 --- a/db2sql/interface/cli/parser.py +++ b/db2sql/interface/cli/parser.py @@ -312,6 +312,15 @@ def _add_dump_options(parser: argparse.ArgumentParser) -> None: default=None, help="Table names to exclude during export (repeatable, comma separated).", ) + parser.add_argument( + "--target-schema", + dest="target_schema", + metavar="NAME", + type=str, + default=None, + help="Redirect all source schemas to NAME in the target.", + action=OnceArgument, + ) parser.add_argument( "-C", "--config-file", @@ -514,6 +523,15 @@ def _add_migrate_subparser(subparsers: Any) -> None: "statement." ), ) + migrate_parser.add_argument( + "--target-schema", + dest="target_schema", + metavar="NAME", + type=str, + default=None, + help="Redirect all source schemas to NAME in the target.", + action=OnceArgument, + ) def _add_init_subparser(subparsers: Any) -> None: diff --git a/tests/unit/domain/model/test_foreign_key.py b/tests/unit/domain/model/test_foreign_key.py index 6adbeb8..e0cb2ba 100644 --- a/tests/unit/domain/model/test_foreign_key.py +++ b/tests/unit/domain/model/test_foreign_key.py @@ -1,4 +1,4 @@ -"""ForeignKey is a frozen value object.""" +"""ForeignKey and ForeignKeyConstraint are frozen value objects.""" from __future__ import annotations @@ -6,7 +6,7 @@ import pytest -from db2sql.domain.model import ForeignKey +from db2sql.domain.model import ForeignKey, ForeignKeyConstraint def test_foreign_key_is_frozen() -> None: @@ -23,3 +23,25 @@ def test_equality_is_structural() -> None: def test_foreign_key_is_hashable() -> None: fk = ForeignKey("s", "t", "c") assert {fk: 1}[fk] == 1 + + +def test_foreign_key_constraint_is_frozen() -> None: + fkc = ForeignKeyConstraint( + name="fk_test", ref_schema="s", ref_table="t", + columns=("a",), ref_columns=("b",), + ) + with pytest.raises(dataclasses.FrozenInstanceError): + fkc.name = "other" # type: ignore[misc] + + +def test_foreign_key_constraint_equality() -> None: + a = ForeignKeyConstraint("fk", "s", "t", ("c1", "c2"), ("r1", "r2")) + b = ForeignKeyConstraint("fk", "s", "t", ("c1", "c2"), ("r1", "r2")) + c = ForeignKeyConstraint("fk", "s", "t", ("c1",), ("r1",)) + assert a == b + assert a != c + + +def test_foreign_key_constraint_is_hashable() -> None: + fkc = ForeignKeyConstraint("fk", "s", "t", ("c1",), ("r1",)) + assert {fkc: 1}[fkc] == 1 diff --git a/tests/unit/infrastructure/emit/test_mssql_emitter.py b/tests/unit/infrastructure/emit/test_mssql_emitter.py index acb131b..d11c393 100644 --- a/tests/unit/infrastructure/emit/test_mssql_emitter.py +++ b/tests/unit/infrastructure/emit/test_mssql_emitter.py @@ -7,7 +7,7 @@ import pytest -from db2sql.domain.model import Column, Database, ForeignKey, Schema, Table +from db2sql.domain.model import Column, Database, ForeignKey, ForeignKeyConstraint, Schema, Table from db2sql.infrastructure.emit.mssql import MssqlSqlEmitter @@ -183,6 +183,9 @@ def test_emit_foreign_keys_with_valid_reference(self) -> None: col = Column(name="author_id", type="int") col.foreign_key = ForeignKey("public", "author", "id") book.add_column(col) + book.foreign_key_constraints.append( + ForeignKeyConstraint("fk_book_author", "public", "author", ("author_id",), ("id",)) + ) db.schemas["public"].add_table(book) sink = _Sink() emitter.emit_foreign_keys(db, sink) @@ -197,11 +200,31 @@ def test_emit_foreign_keys_skips_dangling_refs(self) -> None: col = Column(name="author_id", type="int") col.foreign_key = ForeignKey("missing", "author", "id") book.add_column(col) + book.foreign_key_constraints.append( + ForeignKeyConstraint("fk_book_author", "missing", "author", ("author_id",), ("id",)) + ) db.schemas["public"].add_table(book) sink = _Sink() emitter.emit_foreign_keys(db, sink) assert "ALTER TABLE" not in sink.text + def test_emit_foreign_keys_composite(self) -> None: + emitter = MssqlSqlEmitter(preserve_case=True) + db = self._db() + child = Table(name="child") + child.add_column(Column(name="a", type="int")) + child.add_column(Column(name="b", type="int")) + child.foreign_key_constraints.append( + ForeignKeyConstraint( + "fk_child_author", "public", "author", ("a", "b"), ("id", "name") + ) + ) + db.schemas["public"].add_table(child) + sink = _Sink() + emitter.emit_foreign_keys(db, sink) + assert "ADD FOREIGN KEY ([a], [b])" in sink.text + assert "REFERENCES [public].[author] ([id], [name])" in sink.text + def test_emit_indexes(self) -> None: emitter = MssqlSqlEmitter(preserve_case=True) db = self._db() diff --git a/tests/unit/infrastructure/emit/test_postgres_emitter.py b/tests/unit/infrastructure/emit/test_postgres_emitter.py index ea5db65..4e73e16 100644 --- a/tests/unit/infrastructure/emit/test_postgres_emitter.py +++ b/tests/unit/infrastructure/emit/test_postgres_emitter.py @@ -6,7 +6,7 @@ import pytest -from db2sql.domain.model import Column, Database, ForeignKey, Schema, Table +from db2sql.domain.model import Column, Database, ForeignKey, ForeignKeyConstraint, Schema, Table from db2sql.infrastructure.emit.postgres import PostgresSqlEmitter @@ -137,6 +137,9 @@ def test_emit_foreign_keys_skips_dangling_refs(self) -> None: col = Column(name="author_id", type="int") col.foreign_key = ForeignKey("missing", "author", "id") book.add_column(col) + book.foreign_key_constraints.append( + ForeignKeyConstraint("fk_book_author", "missing", "author", ("author_id",), ("id",)) + ) db.schemas["public"].add_table(book) sink = _Sink() @@ -150,6 +153,9 @@ def test_emit_foreign_keys_with_valid_reference(self) -> None: col = Column(name="author_id", type="int") col.foreign_key = ForeignKey("public", "author", "id") book.add_column(col) + book.foreign_key_constraints.append( + ForeignKeyConstraint("fk_book_author", "public", "author", ("author_id",), ("id",)) + ) db.schemas["public"].add_table(book) sink = _Sink() emitter.emit_foreign_keys(db, sink) @@ -244,6 +250,23 @@ def test_emit_truncates_respects_schema_mapping(self) -> None: emitter.emit_truncates(db, sink) assert '"public"."author"' in sink.text + def test_emit_foreign_keys_composite(self) -> None: + emitter = PostgresSqlEmitter(preserve_case=True) + db = self._db() + child = Table(name="child") + child.add_column(Column(name="a", type="int")) + child.add_column(Column(name="b", type="int")) + child.foreign_key_constraints.append( + ForeignKeyConstraint( + "fk_child_author", "public", "author", ("a", "b"), ("id", "name") + ) + ) + db.schemas["public"].add_table(child) + sink = _Sink() + emitter.emit_foreign_keys(db, sink) + assert 'ADD FOREIGN KEY ("a", "b")' in sink.text + assert 'REFERENCES "public"."author" ("id", "name")' in sink.text + def test_emit_foreign_keys_skips_when_ref_table_missing(self) -> None: emitter = PostgresSqlEmitter(preserve_case=True) db = self._db() @@ -252,6 +275,11 @@ def test_emit_foreign_keys_skips_when_ref_table_missing(self) -> None: col = Column(name="author_id", type="int") col.foreign_key = ForeignKey("public", "no_such_table", "id") book.add_column(col) + book.foreign_key_constraints.append( + ForeignKeyConstraint( + "fk_book_nosuch", "public", "no_such_table", ("author_id",), ("id",) + ) + ) db.schemas["public"].add_table(book) sink = _Sink() emitter.emit_foreign_keys(db, sink) diff --git a/tests/unit/infrastructure/persistence/test_mssql_reader.py b/tests/unit/infrastructure/persistence/test_mssql_reader.py index ee50d0d..dcc69df 100644 --- a/tests/unit/infrastructure/persistence/test_mssql_reader.py +++ b/tests/unit/infrastructure/persistence/test_mssql_reader.py @@ -137,6 +137,7 @@ def _populated_session() -> FakeSession: "REFERENTIAL_CONSTRAINTS", [ FakeRow( + CONSTRAINT_NAME="FK_Order_Customer", TABLE_SCHEMA="dbo", TABLE_NAME="Order", COLUMN_NAME="CustomerId", @@ -145,6 +146,7 @@ def _populated_session() -> FakeSession: UNIQUE_COLUMN_NAME="Id", ), FakeRow( + CONSTRAINT_NAME="FK_Order_Missing", TABLE_SCHEMA="dbo", TABLE_NAME="Order", COLUMN_NAME="Missing", @@ -153,6 +155,7 @@ def _populated_session() -> FakeSession: UNIQUE_COLUMN_NAME="Id", ), FakeRow( + CONSTRAINT_NAME="FK_Phantom_Customer", TABLE_SCHEMA="dbo", TABLE_NAME="Phantom", COLUMN_NAME="X", diff --git a/tests/unit/infrastructure/persistence/test_mysql_reader.py b/tests/unit/infrastructure/persistence/test_mysql_reader.py index 533a99d..449d1a4 100644 --- a/tests/unit/infrastructure/persistence/test_mysql_reader.py +++ b/tests/unit/infrastructure/persistence/test_mysql_reader.py @@ -104,18 +104,21 @@ def _full_plan() -> FakeSession: "REFERENCED_TABLE_NAME IS NOT NULL", [ FakeRow( + CONSTRAINT_NAME="fk_book_author", TABLE_NAME="book", COLUMN_NAME="author_id", REFERENCED_TABLE_NAME="author", REFERENCED_COLUMN_NAME="id", ), FakeRow( + CONSTRAINT_NAME="fk_book_missing", TABLE_NAME="book", COLUMN_NAME="missing_col", REFERENCED_TABLE_NAME="author", REFERENCED_COLUMN_NAME="id", ), FakeRow( + CONSTRAINT_NAME="fk_ghost_author", TABLE_NAME="ghost", COLUMN_NAME="x", REFERENCED_TABLE_NAME="author", diff --git a/tests/unit/infrastructure/persistence/test_oracle_reader.py b/tests/unit/infrastructure/persistence/test_oracle_reader.py index 12ef905..69514e4 100644 --- a/tests/unit/infrastructure/persistence/test_oracle_reader.py +++ b/tests/unit/infrastructure/persistence/test_oracle_reader.py @@ -94,6 +94,7 @@ def _full_plan() -> FakeSession: table_name="EMP", column_name="DEPT_ID", position=1, + constraint_name="FK_EMP_DEPT", ref_owner="HR", ref_table="DEPT", ref_column="ID", @@ -374,6 +375,7 @@ def test_collect_metadata_skips_rows_for_missing_columns_and_tables() -> None: table_name="EMP", column_name="MISSING", position=1, + constraint_name="FK_EMP_MISSING", ref_owner="HR", ref_table="DEPT", ref_column="ID", @@ -383,6 +385,7 @@ def test_collect_metadata_skips_rows_for_missing_columns_and_tables() -> None: table_name="GHOST", column_name="X", position=1, + constraint_name="FK_GHOST_DEPT", ref_owner="HR", ref_table="DEPT", ref_column="ID", diff --git a/tests/unit/infrastructure/persistence/test_postgres_reader.py b/tests/unit/infrastructure/persistence/test_postgres_reader.py index e10f53c..187bc3d 100644 --- a/tests/unit/infrastructure/persistence/test_postgres_reader.py +++ b/tests/unit/infrastructure/persistence/test_postgres_reader.py @@ -100,6 +100,7 @@ def _populated_session() -> FakeSession: "REFERENTIAL_CONSTRAINTS", [ FakeRow( + constraint_name="fk_book_author", table_schema="public", table_name="book", column_name="author_id", @@ -107,8 +108,9 @@ def _populated_session() -> FakeSession: ref_table="author", ref_column="id", ), - # column missing → ignored + # column missing → ignored (separate constraint) FakeRow( + constraint_name="fk_book_zzz", table_schema="public", table_name="book", column_name="zzz", @@ -118,6 +120,7 @@ def _populated_session() -> FakeSession: ), # table missing FakeRow( + constraint_name="fk_phantom_x", table_schema="public", table_name="phantom", column_name="x",