Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .docker/mssql/init/01-schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ GO
IF OBJECT_ID('apptest.book', 'U') IS NOT NULL DROP TABLE apptest.book;
IF OBJECT_ID('apptest.author', 'U') IS NOT NULL DROP TABLE apptest.author;
IF OBJECT_ID('apptest.type_matrix', 'U') IS NOT NULL DROP TABLE apptest.type_matrix;
IF OBJECT_ID('apptest.default_matrix', 'U') IS NOT NULL DROP TABLE apptest.default_matrix;
GO

-- Type-coverage table -------------------------------------------------------
Expand Down Expand Up @@ -56,6 +57,30 @@ CREATE TABLE apptest.type_matrix (
);
GO

-- Default-value coverage table -------------------------------------------
-- Every column exercises a DEFAULT expression that the postgres emitter is
-- expected to translate (functions, bare keywords, literals, booleans).
CREATE TABLE apptest.default_matrix (
id INT IDENTITY(1,1) PRIMARY KEY,
d_getdate DATETIME NOT NULL DEFAULT GETDATE(),
d_sysdatetime DATETIME2 NOT NULL DEFAULT SYSDATETIME(),
d_getutcdate DATETIME NOT NULL DEFAULT GETUTCDATE(),
d_sysutcdatetime DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(),
d_sysdatetimeoffset DATETIMEOFFSET NOT NULL DEFAULT SYSDATETIMEOFFSET(),
d_current_timestamp DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
d_newid UNIQUEIDENTIFIER NOT NULL DEFAULT NEWID(),
d_newsequentialid UNIQUEIDENTIFIER NOT NULL DEFAULT NEWSEQUENTIALID(),
d_suser_sname NVARCHAR(128) NOT NULL DEFAULT SUSER_SNAME(),
d_system_user NVARCHAR(128) NOT NULL DEFAULT SYSTEM_USER,
d_user_name NVARCHAR(128) NOT NULL DEFAULT USER_NAME(),
d_db_name NVARCHAR(128) NOT NULL DEFAULT DB_NAME(),
d_bit_true BIT NOT NULL DEFAULT 1,
d_bit_false BIT NOT NULL DEFAULT 0,
d_int_literal INT NOT NULL DEFAULT 42,
d_string_literal NVARCHAR(32) NOT NULL DEFAULT N'hello'
);
GO

-- Relational mini-fixture (parallels the sqlite fixture in tests/conftest.py)
CREATE TABLE apptest.author (
id INT IDENTITY(1,1) PRIMARY KEY,
Expand Down
105 changes: 103 additions & 2 deletions db2sql/infrastructure/emit/mssql/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,61 @@

from __future__ import annotations

import re
import warnings
from typing import Any, Dict, Iterable, Mapping, Optional

from db2sql.application.ports import OutputSink
from db2sql.domain.model import Column, Database, Schema, Table
from db2sql.domain.policy import drop_order, normalize_identifier

# Source-side scalar functions to rewrite when targeting MSSQL. Keys lowercased,
# parens stripped; matched with or without empty parens so Oracle bare keywords
# (``SYSDATE``) and PG / MSSQL function calls are both handled.
_DEFAULT_FUNCTION_MAP: Dict[str, str] = {
# PG date/time
"now": "SYSDATETIME()",
"localtimestamp": "SYSDATETIME()",
"transaction_timestamp": "SYSDATETIME()",
"statement_timestamp": "SYSDATETIME()",
"clock_timestamp": "SYSDATETIME()",
"current_date": "CAST(SYSDATETIME() AS DATE)",
"current_time": "CAST(SYSDATETIME() AS TIME)",
# MySQL date/time (NOW already covered as ``now``); UTC variant.
"utc_timestamp": "SYSUTCDATETIME()",
# Oracle date/time bare keywords
"sysdate": "GETDATE()",
"systimestamp": "SYSDATETIME()",
# uuid generators
"gen_random_uuid": "NEWID()",
"uuid_generate_v4": "NEWID()",
"sys_guid": "NEWID()", # Oracle
"uuid": "NEWID()", # MySQL
# session info — leave CURRENT_USER / SESSION_USER / SYSTEM_USER and
# CURRENT_TIMESTAMP alone, they are ANSI-compatible in MSSQL.
"current_database": "DB_NAME()",
"current_catalog": "DB_NAME()",
"current_schema": "SCHEMA_NAME()",
}

_FUNCTION_CALL_RE = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)(?:\s*\(\s*\))?\s*$")
# PG ``literal::type`` cast — only when the whole expression is a single
# literal followed by a single cast. Anything more complex is left as-is.
_PG_CAST_RE = re.compile(
r"""^\s*
(?P<value>
NULL
| TRUE | FALSE
| -?\d+(?:\.\d+)?
| '(?:[^']|'')*'
)
\s*::\s*[A-Za-z_][A-Za-z_0-9 ]*(?:\([^)]*\))?
\s*$""",
re.IGNORECASE | re.VERBOSE,
)
# MySQL ``bit`` default literal — ``b'0'`` / ``b'1'``.
_MYSQL_BIT_RE = re.compile(r"^(?i:b)'([01]+)'$")


class MssqlSqlEmitter:
"""Produce Microsoft SQL Server DDL+DML for a collected :class:`Database`."""
Expand Down Expand Up @@ -99,7 +147,8 @@ def _normalize(self, name: str) -> str:

def quote_identifier(self, name: str) -> str:
normalized = self._normalize(name)
return "[{}]".format(normalized.replace("]", "]]"))
escaped = normalized.replace("]", "]]")
return f"[{escaped}]"

def schema_name(self, schema: Schema) -> str:
mapped = self._schema_mapping.get(schema.name, schema.name)
Expand Down Expand Up @@ -132,9 +181,61 @@ def column_definition(self, column: Column) -> str:
if not column.nullable and not column.identity:
parts.append("NOT NULL")
if column.default is not None and not column.identity:
parts.append(f"DEFAULT {column.default}")
parts.append(f"DEFAULT {self._translate_default(column.default, target_type)}")
return " ".join(parts)

@staticmethod
def _strip_pg_cast(expr: str) -> str:
match = _PG_CAST_RE.match(expr)
if match:
return match.group("value")
return expr

@staticmethod
def _strip_wrapping_parens(expr: str) -> str:
# MSSQL-sourced defaults arrive wrapped — peel only when the outer pair
# encloses the whole expression so we leave ``(1)+(2)`` alone.
expr = expr.strip()
while expr.startswith("(") and expr.endswith(")"):
depth = 0
balanced = True
for index, ch in enumerate(expr):
if ch == "(":
depth += 1
elif ch == ")":
depth -= 1
if depth == 0 and index != len(expr) - 1:
balanced = False
break
if not balanced:
break
expr = expr[1:-1].strip()
return expr

def _translate_default(self, raw: str, target_type: str) -> str:
expr = self._strip_wrapping_parens(raw)
expr = self._strip_pg_cast(expr)

# PG / MySQL boolean literals → MSSQL bit literal when target is ``bit``.
if target_type == "bit":
lower = expr.lower()
if lower == "true":
return "1"
if lower == "false":
return "0"
bit_match = _MYSQL_BIT_RE.match(expr)
if bit_match:
return "1" if int(bit_match.group(1), 2) else "0"

match = _FUNCTION_CALL_RE.match(expr)
if match:
fn = match.group(1).lower()
replacement = _DEFAULT_FUNCTION_MAP.get(fn)
if replacement is not None:
return replacement

return expr

# ---- emit -------------------------------------------------------------

def emit_prologue(self, sink: OutputSink) -> None:
Expand Down
86 changes: 84 additions & 2 deletions db2sql/infrastructure/emit/postgres/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,45 @@

from __future__ import annotations

import re
from typing import Any, Dict, Iterable, Mapping, Optional

from db2sql.application.ports import OutputSink
from db2sql.domain.model import Column, Database, Schema, Table
from db2sql.domain.policy import drop_order, normalize_identifier, topological_order

# Source-side scalar functions that have no PG equivalent under the same name.
# Keys are lowercased, parens stripped; values are PG expressions to substitute.
# Matched regardless of whether the source wrote them with empty parens
# (``getdate()``) or as bare keywords (Oracle ``SYSDATE``).
_DEFAULT_FUNCTION_MAP: Dict[str, str] = {
# MSSQL date/time
"getdate": "now()",
"sysdatetime": "LOCALTIMESTAMP",
"getutcdate": "(now() AT TIME ZONE 'utc')",
"sysutcdatetime": "(now() AT TIME ZONE 'utc')",
"sysdatetimeoffset": "now()",
# Oracle date/time (bare keywords, no parens)
"sysdate": "now()",
"systimestamp": "now()",
# MSSQL / Oracle / MySQL uuid generators
"newid": "gen_random_uuid()",
"newsequentialid": "gen_random_uuid()",
"sys_guid": "gen_random_uuid()",
"uuid": "gen_random_uuid()",
# Session info — MSSQL / Oracle
"suser_sname": "CURRENT_USER",
"system_user": "CURRENT_USER",
"user_name": "CURRENT_USER",
"user": "CURRENT_USER",
"db_name": "current_database()",
}

_UNICODE_STRING_RE = re.compile(r"(?i)\bN'")
_FUNCTION_CALL_RE = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)(?:\s*\(\s*\))?\s*$")
# MySQL ``bit`` default literal — ``b'0'`` / ``b'1'``.
_MYSQL_BIT_RE = re.compile(r"^(?i:b)'([01]+)'$")


class PostgresSqlEmitter:
"""Produce PostgreSQL DDL+DML for a collected :class:`Database`."""
Expand Down Expand Up @@ -90,7 +123,8 @@ def _normalize(self, name: str) -> str:

def quote_identifier(self, name: str) -> str:
normalized = self._normalize(name)
return '"{}"'.format(normalized.replace('"', '""'))
escaped = normalized.replace('"', '""')
return f'"{escaped}"'

def schema_name(self, schema: Schema) -> str:
mapped = self._schema_mapping.get(schema.name, schema.name)
Expand Down Expand Up @@ -118,9 +152,57 @@ def column_definition(self, column: Column) -> str:
if not column.nullable and not column.identity:
parts.append("NOT NULL")
if column.default is not None and not column.identity:
parts.append(f"DEFAULT {column.default}")
parts.append(f"DEFAULT {self._translate_default(column.default, target_type)}")
return " ".join(parts)

@staticmethod
def _strip_wrapping_parens(expr: str) -> str:
# MSSQL wraps every default in at least one extra pair of parens
# (``((0))``, ``(getdate())``, ``(N'foo')``). Peel only when the outer
# pair encloses the whole expression — leave ``(1)+(2)`` alone.
expr = expr.strip()
while expr.startswith("(") and expr.endswith(")"):
depth = 0
balanced = True
for index, ch in enumerate(expr):
if ch == "(":
depth += 1
elif ch == ")":
depth -= 1
if depth == 0 and index != len(expr) - 1:
balanced = False
break
if not balanced:
break
expr = expr[1:-1].strip()
return expr

def _translate_default(self, raw: str, target_type: str) -> str:
expr = self._strip_wrapping_parens(raw)

# ``N'foo'`` → ``'foo'``. PG has no N-prefixed string literal; strings
# are already unicode-capable.
expr = _UNICODE_STRING_RE.sub("'", expr)

# ``0`` / ``1`` → ``FALSE`` / ``TRUE`` when the column maps to boolean
# (MSSQL ``bit`` becomes PG ``boolean`` and PG won't coerce int→bool
# inside a DEFAULT clause). MySQL ``bit`` defaults arrive as ``b'1'``.
if target_type == "boolean":
if expr in ("0", "1"):
return "FALSE" if expr == "0" else "TRUE"
bit_match = _MYSQL_BIT_RE.match(expr)
if bit_match:
return "FALSE" if int(bit_match.group(1), 2) == 0 else "TRUE"

match = _FUNCTION_CALL_RE.match(expr)
if match:
fn = match.group(1).lower()
replacement = _DEFAULT_FUNCTION_MAP.get(fn)
if replacement is not None:
return replacement

return expr

# ---- emit -------------------------------------------------------------

def emit_prologue(self, sink: OutputSink) -> None:
Expand Down
2 changes: 1 addition & 1 deletion db2sql/infrastructure/logging/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def init_colorama(stream: object) -> None:
colorama.init()


class Palette:
class Palette: # pylint: disable=too-few-public-methods
"""Wrapper around colorama colors."""

RED = Fore.RED
Expand Down
2 changes: 1 addition & 1 deletion db2sql/infrastructure/logging/console_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def from_verbosity(
raise InvalidLogLevel(level_name) from exc
stream: IO[str]
if log_file:
stream = open(
stream = open( # pylint: disable=consider-using-with
log_file, "wt", encoding="utf-8"
) # noqa: SIM115 — owned for process lifetime
else:
Expand Down
4 changes: 3 additions & 1 deletion db2sql/infrastructure/output/rotating_file_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def _open_next_part(self) -> None:
parent = self._current_path.parent
if str(parent) and not parent.exists():
parent.mkdir(parents=True, exist_ok=True)
self._current_stream = open(self._current_path, "w", encoding="utf-8")
self._current_stream = open( # pylint: disable=consider-using-with
self._current_path, "w", encoding="utf-8"
)
self._current_size = 0

def _part_path(self, index: int) -> Path:
Expand Down
12 changes: 5 additions & 7 deletions db2sql/infrastructure/persistence/mssql/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ def _ensure_session(self) -> Session:
def _connection_string(self) -> str:
server = self._config.server
port = f":{server.port}" if server.port else ""
return "mssql+pymssql://{}:{}@{}{}/{}".format(
server.username or "",
server.password or "",
server.hostname or "",
port,
server.dbname or "",
)
username = server.username or ""
password = server.password or ""
hostname = server.hostname or ""
dbname = server.dbname or ""
return f"mssql+pymssql://{username}:{password}@{hostname}{port}/{dbname}"

def collect_metadata(self) -> Database:
database = Database(str(self._config.server.dbname or ""))
Expand Down
12 changes: 5 additions & 7 deletions db2sql/infrastructure/persistence/mysql/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ def _ensure_session(self) -> Session:
def _connection_string(self) -> str:
server = self._config.server
port = f":{server.port}" if server.port else ""
return "mysql+pymysql://{}:{}@{}{}/{}".format(
server.username or "",
server.password or "",
server.hostname or "",
port,
server.dbname or "",
)
username = server.username or ""
password = server.password or ""
hostname = server.hostname or ""
dbname = server.dbname or ""
return f"mysql+pymysql://{username}:{password}@{hostname}{port}/{dbname}"

@property
def _database_name(self) -> str:
Expand Down
6 changes: 3 additions & 3 deletions db2sql/infrastructure/persistence/oracle/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _connection_string(self) -> str:
options = server.options or {}
driver = options.get("driver", "oracledb")
port = f":{server.port}" if server.port else ""
userinfo = "{}:{}".format(server.username or "", server.password or "")
userinfo = f"{server.username or ''}:{server.password or ''}"
host = server.hostname or ""
service_name = options.get("service_name")
sid = options.get("sid")
Expand Down Expand Up @@ -150,7 +150,7 @@ def _read_schemas(self, database: Database) -> None:
owner = self._schema_filter
params: Dict[str, Any] = {}
if owner:
query = "SELECT DISTINCT OWNER FROM ALL_TABLES " "WHERE OWNER = :owner ORDER BY OWNER"
query = "SELECT DISTINCT OWNER FROM ALL_TABLES WHERE OWNER = :owner ORDER BY OWNER"
params["owner"] = owner
else:
query = (
Expand Down Expand Up @@ -342,7 +342,7 @@ def _read_identity_columns(self, database: Database) -> None:
),
params,
)
except Exception:
except Exception: # pylint: disable=broad-exception-caught
return
for row in rows:
table = database.get_table(row.owner, row.table_name)
Expand Down
Loading
Loading