Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion db2sql/application/use_cases/migrate_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def _load_data(self, database: Database) -> None:
rows = self._reader.iter_query_rows(table.source_query, limit=limit)
else:
rows = self._reader.iter_rows(schema_name, table, limit=limit)
self._writer.bulk_load(schema_name, table, rows)
mapped = options.mapping_schemas.get(
schema_name, options.mapping_schemas.get("*", schema_name)
)
self._writer.bulk_load(mapped, table, rows)

@staticmethod
def _resolve_limit(
Expand Down
4 changes: 2 additions & 2 deletions db2sql/domain/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from .column import Column
from .database import Database
from .foreign_key import ForeignKey
from .foreign_key import ForeignKey, ForeignKeyConstraint
from .schema import Schema
from .table import Table

__all__ = ["Column", "Database", "ForeignKey", "Schema", "Table"]
__all__ = ["Column", "Database", "ForeignKey", "ForeignKeyConstraint", "Schema", "Table"]
14 changes: 13 additions & 1 deletion db2sql/domain/model/foreign_key.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Foreign key value object."""
"""Foreign key value objects."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
Expand All @@ -12,3 +13,14 @@ class ForeignKey:
schema: str
table: str
column: str


@dataclass(frozen=True)
class ForeignKeyConstraint:
"""A named FK constraint grouping one or more columns. Immutable value object."""

name: str
ref_schema: str
ref_table: str
columns: Tuple[str, ...]
ref_columns: Tuple[str, ...]
2 changes: 2 additions & 0 deletions db2sql/domain/model/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from db2sql.domain.errors import DuplicatedColumnError

from .column import Column
from .foreign_key import ForeignKeyConstraint


@dataclass
Expand All @@ -17,6 +18,7 @@ class Table:
name: str
columns: Dict[str, Column] = field(default_factory=dict)
indexes: Dict[str, List[str]] = field(default_factory=dict)
foreign_key_constraints: List[ForeignKeyConstraint] = field(default_factory=list)
source_query: Optional[str] = None

def add_column(self, column: Column) -> None:
Expand Down
9 changes: 8 additions & 1 deletion db2sql/domain/policy/identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@


def to_snake_case(name: str) -> str:
"""Convert CamelCase to snake_case (idempotent on already-snake_case input)."""
"""Convert CamelCase to snake_case (idempotent on already-snake_case input).

Pure-uppercase names (e.g. ``CLIENTS``) are lowered without inserting
underscores, so ``CLIENTS`` becomes ``clients`` rather than ``c_l_i_e_n_t_s``.
Mixed-case names like ``ClientName`` become ``client_name``.
"""
if name.isupper() or name.islower():
return name.lower()
return _CAMEL_RE.sub("_", name).lower()


Expand Down
3 changes: 3 additions & 0 deletions db2sql/infrastructure/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def merge_cli_overrides(config: AppConfig, options: Mapping[str, Any]) -> AppCon
dump["default_data_format"] = value
elif key == "output_file_name":
data["output_file"] = value
elif key == "target_schema":
schemas = dump.setdefault("mapping_schemas", {})
schemas["*"] = value
elif key == "split_size":
data["split_size"] = value

Expand Down
22 changes: 11 additions & 11 deletions db2sql/infrastructure/emit/mssql/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def quote_identifier(self, name: str) -> str:
return "[{}]".format(normalized.replace("]", "]]"))

def schema_name(self, schema: Schema) -> str:
mapped = self._schema_mapping.get(schema.name, schema.name)
mapped = self._schema_mapping.get(schema.name, self._schema_mapping.get("*", schema.name))
return self.quote_identifier(mapped)

def table_name(self, schema: Schema, table: Table) -> str:
Expand Down Expand Up @@ -148,7 +148,9 @@ def emit_epilogue(self, sink: OutputSink) -> None:
def emit_schemas(self, database: Database, sink: OutputSink) -> None:
emitted = set()
for schema in database.schemas.values():
target = self._schema_mapping.get(schema.name, schema.name)
target = self._schema_mapping.get(
schema.name, self._schema_mapping.get("*", schema.name)
)
if target in emitted:
continue
emitted.add(target)
Expand Down Expand Up @@ -213,22 +215,20 @@ def emit_foreign_keys(self, database: Database, sink: OutputSink) -> None:
for schema in database.schemas.values():
for table in schema.tables.values():
qualified = self.table_name(schema, table)
for column in table.columns.values():
fk = column.foreign_key
if not fk:
continue
ref_schema = database.schemas.get(fk.schema)
for fkc in table.foreign_key_constraints:
ref_schema = database.schemas.get(fkc.ref_schema)
if ref_schema is None:
continue
ref_table = ref_schema.get_table(fk.table)
ref_table = ref_schema.get_table(fkc.ref_table)
if ref_table is None:
continue
ref_qualified = self.table_name(ref_schema, ref_table)
cols = ", ".join(self.quote_identifier(c) for c in fkc.columns)
ref_cols = ", ".join(self.quote_identifier(c) for c in fkc.ref_columns)
sink.write(
f"ALTER TABLE {qualified} "
f"ADD FOREIGN KEY ({self.quote_identifier(column.name)}) "
f"REFERENCES {ref_qualified} "
f"({self.quote_identifier(fk.column)});\n"
f"ADD FOREIGN KEY ({cols}) "
f"REFERENCES {ref_qualified} ({ref_cols});\n"
)
sink.boundary()
sink.write("\n")
Expand Down
62 changes: 49 additions & 13 deletions db2sql/infrastructure/emit/postgres/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,27 @@

from __future__ import annotations

import re
from typing import Any, Dict, Iterable, Mapping, Optional

from db2sql.application.ports import OutputSink
from db2sql.domain.model import Column, Database, Schema, Table
from db2sql.domain.policy import drop_order, normalize_identifier, topological_order

_DEFAULT_FUNCTION_MAP: Dict[str, str] = {
"getdate()": "now()",
"getutcdate()": "(now() AT TIME ZONE 'UTC')",
"sysdatetime()": "now()",
"sysutcdatetime()": "(now() AT TIME ZONE 'UTC')",
"newid()": "gen_random_uuid()",
"newsequentialid()": "gen_random_uuid()",
}

_DEFAULT_FUNCTION_RE = re.compile(
"|".join(re.escape(k) for k in _DEFAULT_FUNCTION_MAP),
re.IGNORECASE,
)


class PostgresSqlEmitter:
"""Produce PostgreSQL DDL+DML for a collected :class:`Database`."""
Expand Down Expand Up @@ -93,7 +108,7 @@ def quote_identifier(self, name: str) -> str:
return '"{}"'.format(normalized.replace('"', '""'))

def schema_name(self, schema: Schema) -> str:
mapped = self._schema_mapping.get(schema.name, schema.name)
mapped = self._schema_mapping.get(schema.name, self._schema_mapping.get("*", schema.name))
return self.quote_identifier(mapped)

def table_name(self, schema: Schema, table: Table) -> str:
Expand All @@ -106,7 +121,17 @@ def _map_type(self, column: Column) -> str:
if column.char_length and column.char_length > 0:
return f"{target}({column.char_length})"
if target == "numeric" and column.precision:
scale = column.scale or 0
scale = column.scale if column.scale is not None else 0
if column.scale is not None and scale == 0:
# numeric(p,0) is an integer stored in a decimal type (common
# MSSQL pattern). Promote to a native integer type so that
# serial/bigserial works for identity columns and FK types
# stay consistent across referencing tables.
if column.precision > 9:
return "bigint"
if column.precision > 4:
return "integer"
return "smallint"
return f"numeric({column.precision},{scale})"
return target

Expand All @@ -118,9 +143,20 @@ def column_definition(self, column: Column) -> str:
if not column.nullable and not column.identity:
parts.append("NOT NULL")
if column.default is not None and not column.identity:
parts.append(f"DEFAULT {column.default}")
parts.append(f"DEFAULT {self._translate_default(column.default)}")
return " ".join(parts)

@staticmethod
def _translate_default(default: str) -> str:
"""Translate source-dialect default expressions to PostgreSQL."""
# Strip wrapping parentheses added by MSSQL (e.g. "(getdate())")
stripped = default.strip()
while stripped.startswith("(") and stripped.endswith(")"):
stripped = stripped[1:-1].strip()
return _DEFAULT_FUNCTION_RE.sub(
lambda m: _DEFAULT_FUNCTION_MAP[m.group(0).lower()], stripped
)

# ---- emit -------------------------------------------------------------

def emit_prologue(self, sink: OutputSink) -> None:
Expand All @@ -134,7 +170,9 @@ def emit_epilogue(self, sink: OutputSink) -> None:
def emit_schemas(self, database: Database, sink: OutputSink) -> None:
emitted = set()
for schema in database.schemas.values():
target = self._schema_mapping.get(schema.name, schema.name)
target = self._schema_mapping.get(
schema.name, self._schema_mapping.get("*", schema.name)
)
if target in emitted:
continue
emitted.add(target)
Expand Down Expand Up @@ -190,22 +228,20 @@ def emit_foreign_keys(self, database: Database, sink: OutputSink) -> None:
for schema in database.schemas.values():
for table in schema.tables.values():
qualified = self.table_name(schema, table)
for column in table.columns.values():
fk = column.foreign_key
if not fk:
continue
ref_schema = database.schemas.get(fk.schema)
for fkc in table.foreign_key_constraints:
ref_schema = database.schemas.get(fkc.ref_schema)
if ref_schema is None:
continue
ref_table = ref_schema.get_table(fk.table)
ref_table = ref_schema.get_table(fkc.ref_table)
if ref_table is None:
continue
ref_qualified = self.table_name(ref_schema, ref_table)
cols = ", ".join(self.quote_identifier(c) for c in fkc.columns)
ref_cols = ", ".join(self.quote_identifier(c) for c in fkc.ref_columns)
sink.write(
f"ALTER TABLE {qualified} "
f"ADD FOREIGN KEY ({self.quote_identifier(column.name)}) "
f"REFERENCES {ref_qualified} "
f"({self.quote_identifier(fk.column)});\n"
f"ADD FOREIGN KEY ({cols}) "
f"REFERENCES {ref_qualified} ({ref_cols});\n"
)
sink.boundary()
sink.write("\n")
Expand Down
57 changes: 45 additions & 12 deletions db2sql/infrastructure/persistence/mssql/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@

from __future__ import annotations

from typing import Any, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple
from urllib.parse import quote_plus

from sqlalchemy import create_engine, engine, text
from sqlalchemy.orm.session import Session, sessionmaker

from db2sql.application.ports import Logger
from db2sql.domain.model import Column, Database, ForeignKey, Schema, Table
from db2sql.domain.model import (
Column,
Database,
ForeignKey,
ForeignKeyConstraint,
Schema,
Table,
)
from db2sql.infrastructure.config import AppConfig
from db2sql.infrastructure.persistence import query_introspection
from db2sql.infrastructure.persistence.errors import SourceReaderError
Expand All @@ -35,8 +43,8 @@ def _connection_string(self) -> str:
server = self._config.server
port = f":{server.port}" if server.port else ""
return "mssql+pymssql://{}:{}@{}{}/{}".format(
server.username or "",
server.password or "",
quote_plus(server.username or ""),
quote_plus(server.password or ""),
server.hostname or "",
port,
server.dbname or "",
Expand Down Expand Up @@ -254,21 +262,46 @@ def _read_foreign_keys(self, database: Database) -> None:
AND KCU2.CONSTRAINT_NAME = RC.UNIQUE_CONSTRAINT_NAME
WHERE KCU1.ORDINAL_POSITION = KCU2.ORDINAL_POSITION
AND KCU1.TABLE_SCHEMA not in ('sys', 'guest', 'information_schema')
ORDER BY CONSTRAINT_SCHEMA, CONSTRAINT_NAME
ORDER BY CONSTRAINT_SCHEMA, CONSTRAINT_NAME, KCU1.ORDINAL_POSITION
"""
)
)

# Group FK rows by (schema, table, constraint_name)
groups: Dict[Tuple[str, str, str], List[Any]] = {}
for row in r:
table = database.get_table(row.TABLE_SCHEMA, row.TABLE_NAME)
if table:
key = (row.TABLE_SCHEMA, row.TABLE_NAME, row.CONSTRAINT_NAME)
groups.setdefault(key, []).append(row)

for (schema_name, table_name, constraint_name), rows in groups.items():
table = database.get_table(schema_name, table_name)
if table is None:
continue
cols: List[str] = []
ref_cols: List[str] = []
valid = True
for row in rows:
column = table.get_column(row.COLUMN_NAME)
if column:
column.foreign_key = ForeignKey(
row.UNIQUE_TABLE_SCHEMA,
row.UNIQUE_TABLE_NAME,
row.UNIQUE_COLUMN_NAME,
if column is None:
valid = False
break
column.foreign_key = ForeignKey(
row.UNIQUE_TABLE_SCHEMA,
row.UNIQUE_TABLE_NAME,
row.UNIQUE_COLUMN_NAME,
)
cols.append(row.COLUMN_NAME)
ref_cols.append(row.UNIQUE_COLUMN_NAME)
if valid and cols:
table.foreign_key_constraints.append(
ForeignKeyConstraint(
name=constraint_name,
ref_schema=rows[0].UNIQUE_TABLE_SCHEMA,
ref_table=rows[0].UNIQUE_TABLE_NAME,
columns=tuple(cols),
ref_columns=tuple(ref_cols),
)
)

def _read_indexes(self, database: Database) -> None:
r: engine.Result[Any] = self._ensure_session().execute(
Expand Down
Loading