ingestion/src/metadata/ingestion/source/database/common_db_source.py
@@ -62,7 +62,6 @@
 from metadata.ingestion.models.patch_request import PatchedEntity, PatchRequest
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.source.connections import get_connection
-from metadata.ingestion.source.connections_utils import kill_active_connections
 from metadata.ingestion.source.database.database_service import DatabaseServiceSource
 from metadata.ingestion.source.database.sql_column_handler import SqlColumnHandlerMixin
 from metadata.ingestion.source.database.sqlalchemy_source import SqlAlchemySource
@@ -152,15 +151,40 @@ def set_inspector(self, database_name: str) -> None:
         :param database_name: new database to set
         """

-        kill_active_connections(self.engine)
+        self._release_engine()
         logger.info(f"Ingesting from database: {database_name}")

         new_service_connection = deepcopy(self.service_connection)
         new_service_connection.database = database_name
         self.engine = get_connection(new_service_connection)
         self.session = create_and_bind_thread_safe_session(self.engine)

-        self._connection_map = {}  # Lazy init as well
+    def _release_engine(self) -> None:
+        # Close fairies first so _ConnectionRecord drops its pool reference;
+        # dispose alone leaves them orphaned and causes _finalize_fairy
+        # RecursionErrors at GC time. Clearing _inspector_map is what
+        # actually frees Inspector.info_cache; dispose() does not.
+        if getattr(self, "engine", None) is None:
+            return
+        for conn in self._connection_map.values():
+            try:
+                conn.close()
+            except Exception:  # pylint: disable=broad-except
+                logger.debug("Connection already closed", exc_info=True)
+        self._connection_map = {}
+        self._inspector_map = {}
+        session = getattr(self, "session", None)
+        if session is not None:
+            try:
+                session.remove()
+            except Exception:  # pylint: disable=broad-except
+                logger.debug("Session cleanup failed", exc_info=True)
+        self.session = None
+        try:
+            self.engine.dispose()
+        except Exception as exc:  # pylint: disable=broad-except
+            logger.warning(f"Failed to dispose engine: {exc}")
+        self.engine = None

     def get_database_names(self) -> Iterable[str]:
         """
@@ -780,14 +804,10 @@ def inspector(self) -> Inspector:
         return self._inspector_map[thread_id]

     def close(self):
-        if self.connection is not None:
-            self.connection.close()
-        for connection in self._connection_map.values():
-            connection.close()
+        self._release_engine()
         if hasattr(self, "ssl_manager") and self.ssl_manager:
             self.ssl_manager = cast(SSLManager, self.ssl_manager)
             self.ssl_manager.cleanup_temp_files()
-        self.engine.dispose()

     def fetch_table_tags(
         self,
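For context on the hunks above, here is a minimal standalone sketch of the leak that _release_engine closes (illustrative code, not part of this diff; it assumes SQLAlchemy 1.4+ and the stdlib SQLite driver, and the names connection_map/release are hypothetical). Connection objects cached in _connection_map hold strong references to their Engine, so swapping engines in set_inspector without closing and dropping them would pin every previous engine in memory:

import gc
import weakref

from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool


def release(conn_map):
    # what _release_engine does, in miniature: close every cached
    # connection, then drop the references to it
    for cached in conn_map.values():
        cached.close()
    conn_map.clear()


connection_map = {}  # stands in for the per-thread _connection_map

old_engine = create_engine("sqlite:///:memory:", poolclass=QueuePool)
connection_map[123] = old_engine.connect()  # checked-out connection ("fairy")
old_ref = weakref.ref(old_engine)

# Disposing alone is not enough: the cached Connection still references
# the engine, so the engine can never be garbage-collected.
old_engine.dispose()
old_engine = None
gc.collect()
assert old_ref() is not None  # old engine is still pinned

release(connection_map)
gc.collect()
assert old_ref() is None  # old engine is now collectable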
ingestion/src/metadata/ingestion/source/database/multi_db_source.py
@@ -33,7 +33,9 @@ def get_database_names_raw(self) -> Iterable[str]:
"""

def _execute_database_query(self, query: str) -> Iterable[str]:
results = self.connection.execute(text(query)) # pylint: disable=no-member
results = self.connection.execute(
text(query)
).fetchall() # pylint: disable=no-member
for res in results:
row = list(res)
yield row[0]
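The eager .fetchall() here (and in the Snowflake source below) is what lets _release_engine close a mapped connection while a caller is still iterating the generator: a plain execute() result stays tied to a live cursor, while fetchall() buffers the rows into Python memory. A minimal standalone sketch of the difference (not code from this diff; assumes SQLAlchemy 1.4+ with SQLite, and the names lazy/eager are illustrative):

from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///:memory:")
conn = engine.connect()
conn.execute(text("CREATE TABLE dbs (name TEXT)"))
conn.execute(text("INSERT INTO dbs VALUES ('alpha'), ('beta')"))

lazy = conn.execute(text("SELECT name FROM dbs"))  # tied to a live cursor
eager = conn.execute(text("SELECT name FROM dbs")).fetchall()  # buffered rows

conn.close()  # what _release_engine does to every mapped connection

for row in eager:  # fine: the rows already live in Python memory
    print(row[0])

try:
    for row in lazy:  # the underlying cursor died with the connection
        print(row[0])
except Exception as exc:  # typically ResourceClosedError or a DBAPI error
    print(f"lazy iteration failed: {exc!r}")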
ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py
@@ -402,7 +402,7 @@ def get_configured_database(self) -> Optional[str]:
         return self.service_connection.database

     def get_database_names_raw(self) -> Iterable[str]:
-        results = self.connection.execute(text(SNOWFLAKE_GET_DATABASES))
+        results = self.connection.execute(text(SNOWFLAKE_GET_DATABASES)).fetchall()
         for res in results:
             row = list(res)
             yield row[1]
233 changes: 233 additions & 0 deletions ingestion/tests/unit/topology/database/test_common_db_source.py
@@ -13,9 +13,15 @@
 Tests for CommonDbSourceService._prepare_foreign_constraints
 """

+import gc
+import weakref
 from unittest.mock import MagicMock, patch

 import pytest
+from sqlalchemy import create_engine, text
+from sqlalchemy.engine import Engine
+from sqlalchemy.inspection import inspect
+from sqlalchemy.pool import QueuePool

 from metadata.generated.schema.entity.data.table import (
     Column,
@@ -24,8 +30,10 @@
     Table,
     TableConstraint,
 )
+from metadata.ingestion.connections.session import create_and_bind_thread_safe_session
 from metadata.ingestion.source.database.common_db_source import CommonDbSourceService
 from metadata.ingestion.source.database.database_service import DatabaseServiceSource
+from metadata.ingestion.source.database.multi_db_source import MultiDBSource


 @pytest.fixture
@@ -334,3 +342,228 @@ def test_constraint_with_none_columns_skipped(self):
         result = DatabaseServiceSource.normalize_table_constraints(constraints, columns)
         assert result[0].columns is None
         assert result[1].columns == ["id"]
+
+
+class _ReleaseOnlySurrogate(CommonDbSourceService):
+    """
+    Minimal concrete subclass that bypasses CommonDbSourceService.__init__
+    (which needs a full workflow config) so we can drive _release_engine /
+    close against a real SQLAlchemy engine in isolation.
+    """
+
+    def __init__(self, engine=None):  # pylint: disable=super-init-not-called
+        self.engine = engine
+        self._connection_map = {}
+        self._inspector_map = {}
+        self.session = None
+        self.ssl_manager = None
+
+    def create(self, *args, **kwargs):  # satisfy abstract method contract
+        raise NotImplementedError
+
+
+def _make_release_surrogate(engine=None):
+    """
+    Build a minimal stand-in for CommonDbSourceService that has just the
+    attributes _release_engine touches, bypassing the heavy __init__ that
+    requires a full workflow config.
+    """
+    return _ReleaseOnlySurrogate(engine=engine)
+
+
+@pytest.fixture
+def sqlite_engine():
+    """Real, in-memory SQLite engine with an explicit QueuePool."""
+    engine = create_engine("sqlite:///:memory:", poolclass=QueuePool)
+    yield engine
+    try:
+        engine.dispose()
+    except Exception:
+        pass
+
+
+@pytest.fixture
+def surrogate(sqlite_engine):
+    """Minimal CommonDbSourceService with a real engine attached."""
+    return _make_release_surrogate(sqlite_engine)
+
+
+class TestReleaseEngine:
+    """Option B: _release_engine closes all pooled connections, clears
+    inspector/session state, and disposes the engine regardless of
+    which thread called it."""
+
+    def test_closes_every_connection_map_entry(self, surrogate):
+        conn_a = surrogate.engine.connect()
+        conn_b = surrogate.engine.connect()
+        surrogate._connection_map[111] = conn_a
+        surrogate._connection_map[222] = conn_b
+
+        surrogate._release_engine()
+
+        assert conn_a.closed is True
+        assert conn_b.closed is True
+        assert surrogate._connection_map == {}
+
+    def test_clears_inspector_map(self, surrogate):
+        surrogate._connection_map[999] = surrogate.engine.connect()
+        surrogate._inspector_map[999] = inspect(surrogate._connection_map[999])
+        assert len(surrogate._inspector_map) == 1
+
+        surrogate._release_engine()
+
+        assert surrogate._inspector_map == {}
+
+    def test_disposes_pool_and_clears_engine_ref(self, surrogate):
+        captured_engine = surrogate.engine
+        original_pool = captured_engine.pool
+        assert isinstance(original_pool, QueuePool)
+        connection = surrogate.engine.connect()
+        surrogate._connection_map[1] = connection
+
+        surrogate._release_engine()
+
+        assert surrogate.engine is None
+        assert connection.closed is True
+        assert original_pool.checkedout() == 0
+
+    def test_removes_session(self, surrogate):
+        surrogate.session = create_and_bind_thread_safe_session(surrogate.engine)
+        assert surrogate.session is not None
+
+        surrogate._release_engine()
+
+        assert surrogate.session is None
+
+    def test_idempotent_when_engine_is_none(self):
+        surrogate = _make_release_surrogate(engine=None)
+        surrogate._release_engine()
+        assert surrogate.engine is None
+        assert surrogate._connection_map == {}
+        assert surrogate._inspector_map == {}
+
+    def test_tolerates_already_closed_connection(self, surrogate):
+        healthy = surrogate.engine.connect()
+        already_closed = surrogate.engine.connect()
+        already_closed.close()
+        surrogate._connection_map[1] = healthy
+        surrogate._connection_map[2] = already_closed
+
+        surrogate._release_engine()
+
+        assert healthy.closed is True
+        assert surrogate._connection_map == {}
+
+    def test_closes_connections_from_arbitrary_thread_ids(self, surrogate):
+        """Key property of Option B: close-all, not detach-current-thread.
+        Every fairy in _connection_map must close regardless of the caller's
+        thread id."""
+        conns = {
+            111: surrogate.engine.connect(),
+            222: surrogate.engine.connect(),
+            333: surrogate.engine.connect(),
+        }
+        surrogate._connection_map.update(conns)
+
+        surrogate._release_engine()
+
+        for conn in conns.values():
+            assert conn.closed is True
+        assert surrogate._connection_map == {}
+
+
+class TestEngineGcReclamation:
+    """Acceptance test for the memory leak fix: after _release_engine and
+    dropping the strong reference, the old Engine must be garbage-collectable.
+    The previous kill_active_connections path left _ConnectionRecord fairies
+    pinning the engine, which is what this test guards against."""
+
+    def test_old_engine_becomes_gc_eligible_after_release(self):
+        engine = create_engine("sqlite:///:memory:", poolclass=QueuePool)
+        surrogate = _make_release_surrogate(engine)
+        surrogate._connection_map[12345] = surrogate.engine.connect()
+
+        old_engine_ref = weakref.ref(surrogate.engine)
+
+        surrogate._release_engine()
+        surrogate.engine = None
+        engine = None  # drop local strong ref too
+
+        gc.collect()
+
+        assert old_engine_ref() is None
+
+
+class _FakeSource(MultiDBSource):
+    """Minimal MultiDBSource that exposes a real SQLAlchemy connection so we
+    can exercise _execute_database_query against a live cursor."""
+
+    def __init__(self, engine: Engine):
+        self._engine = engine
+        self._conn = engine.connect()
+
+    @property
+    def connection(self):
+        return self._conn
+
+    def close(self):
+        try:
+            self._conn.close()
+        except Exception:
+            pass
+
+    def get_configured_database(self):
+        return None
+
+    def get_database_names_raw(self):
+        return self._execute_database_query("SELECT name FROM dbs ORDER BY id")
+
+
+class TestExecuteDatabaseQueryEagerFetch:
+    """Option B Part 2: _execute_database_query must eagerly .fetchall()
+    so that _release_engine closing the connection in _connection_map
+    (the original regression pattern from set_inspector) does not
+    invalidate the cursor the generator is iterating."""
+
+    @pytest.fixture
+    def seeded_engine(self):
+        engine = create_engine("sqlite:///:memory:", poolclass=QueuePool)
+        with engine.connect() as conn:
+            conn.execute(text("CREATE TABLE dbs (id INTEGER PRIMARY KEY, name TEXT)"))
+            conn.execute(
+                text(
+                    "INSERT INTO dbs(id, name) VALUES (1, 'alpha'), (2, 'beta'), (3, 'gamma')"
+                )
+            )
+            conn.commit()
+        yield engine
+        try:
+            engine.dispose()
+        except Exception:
+            pass
+
+    @pytest.fixture
+    def fake_source(self, seeded_engine):
+        source = _FakeSource(seeded_engine)
+        yield source
+        source.close()
+
+    def test_generator_survives_connection_close_mid_iteration(self, fake_source):
+        # Simulates what _release_engine actually does: it close()s every
+        # connection in _connection_map BEFORE disposing the engine. Without
+        # .fetchall() the cursor would die at that close() and the next
+        # yield would raise; with .fetchall() the rows are already buffered.
+        generator = fake_source.get_database_names_raw()
+
+        first = next(generator)
+        assert first == "alpha"
+
+        fake_source._conn.close()
+
+        remaining = list(generator)
+        assert remaining == ["beta", "gamma"]
+
+    def test_returns_all_rows_in_order(self, fake_source):
+        results = list(fake_source.get_database_names_raw())
+
+        assert results == ["alpha", "beta", "gamma"]
46 changes: 46 additions & 0 deletions ingestion/tests/unit/topology/database/test_snowflake.py
@@ -853,3 +853,49 @@ def test_empty_tag_value_skipped_with_warning(self):
             source.schema_tags_map["TEST_SCHEMA"][0],
             {"tag_name": "TEST_TAG", "tag_value": "123"},
         )
+
+
+class TestSnowflakeGetDatabaseNamesRawEagerFetch:
+    """
+    Option B Part 2 applied to Snowflake: get_database_names_raw must call
+    .fetchall() so that a subsequent engine.dispose() / set_inspector does
+    not invalidate the cursor mid-iteration.
+    """
+
+    @staticmethod
+    def _build_mock_rows():
+        return [
+            ["row_meta", "DB_A"],
+            ["row_meta", "DB_B"],
+            ["row_meta", "DB_C"],
+        ]
+
+    def test_fetchall_invoked_exactly_once(self):
+        source = SnowflakeSource.__new__(SnowflakeSource)
+        result = MagicMock()
+        result.fetchall.return_value = self._build_mock_rows()
+        mock_conn = MagicMock()
+        mock_conn.execute.return_value = result
+
+        with patch.object(
+            SnowflakeSource, "connection", new_callable=PropertyMock
+        ) as mocked_conn_prop:
+            mocked_conn_prop.return_value = mock_conn
+            list(source.get_database_names_raw())
+
+        assert result.fetchall.call_count == 1
+
+    def test_yields_database_names_in_order(self):
+        source = SnowflakeSource.__new__(SnowflakeSource)
+        result = MagicMock()
+        result.fetchall.return_value = self._build_mock_rows()
+        mock_conn = MagicMock()
+        mock_conn.execute.return_value = result
+
+        with patch.object(
+            SnowflakeSource, "connection", new_callable=PropertyMock
+        ) as mocked_conn_prop:
+            mocked_conn_prop.return_value = mock_conn
+            names = list(source.get_database_names_raw())
+
+        assert names == ["DB_A", "DB_B", "DB_C"]