diff --git a/ingestion/src/metadata/ingestion/source/database/common_db_source.py b/ingestion/src/metadata/ingestion/source/database/common_db_source.py
index f031884285dd..394161f811aa 100644
--- a/ingestion/src/metadata/ingestion/source/database/common_db_source.py
+++ b/ingestion/src/metadata/ingestion/source/database/common_db_source.py
@@ -62,7 +62,6 @@
 from metadata.ingestion.models.patch_request import PatchedEntity, PatchRequest
 from metadata.ingestion.ometa.ometa_api import OpenMetadata
 from metadata.ingestion.source.connections import get_connection
-from metadata.ingestion.source.connections_utils import kill_active_connections
 from metadata.ingestion.source.database.database_service import DatabaseServiceSource
 from metadata.ingestion.source.database.sql_column_handler import SqlColumnHandlerMixin
 from metadata.ingestion.source.database.sqlalchemy_source import SqlAlchemySource
@@ -152,15 +151,40 @@ def set_inspector(self, database_name: str) -> None:
         :param database_name: new database to set
         """
-        kill_active_connections(self.engine)
+        self._release_engine()
         logger.info(f"Ingesting from database: {database_name}")
         new_service_connection = deepcopy(self.service_connection)
         new_service_connection.database = database_name
         self.engine = get_connection(new_service_connection)
-        self._connection_map = {}  # Lazy init as well
+        self.session = create_and_bind_thread_safe_session(self.engine)
+
+    def _release_engine(self) -> None:
+        # Close fairies first so _ConnectionRecord drops its pool reference;
+        # dispose() alone leaves them orphaned and causes _finalize_fairy
+        # RecursionErrors at GC time. Clearing _inspector_map is what
+        # actually frees Inspector.info_cache; dispose() does not.
+        if getattr(self, "engine", None) is None:
+            return
+        for conn in self._connection_map.values():
+            try:
+                conn.close()
+            except Exception:  # pylint: disable=broad-except
+                logger.debug("Connection already closed", exc_info=True)
+        self._connection_map = {}
         self._inspector_map = {}
+        session = getattr(self, "session", None)
+        if session is not None:
+            try:
+                session.remove()
+            except Exception:  # pylint: disable=broad-except
+                logger.debug("Session cleanup failed", exc_info=True)
+        self.session = None
+        try:
+            self.engine.dispose()
+        except Exception as exc:  # pylint: disable=broad-except
+            logger.warning(f"Failed to dispose engine: {exc}")
+        self.engine = None
 
     def get_database_names(self) -> Iterable[str]:
         """
@@ -780,14 +804,10 @@ def inspector(self) -> Inspector:
         return self._inspector_map[thread_id]
 
     def close(self):
-        if self.connection is not None:
-            self.connection.close()
-        for connection in self._connection_map.values():
-            connection.close()
+        self._release_engine()
         if hasattr(self, "ssl_manager") and self.ssl_manager:
             self.ssl_manager = cast(SSLManager, self.ssl_manager)
             self.ssl_manager.cleanup_temp_files()
-        self.engine.dispose()
 
     def fetch_table_tags(
         self,
diff --git a/ingestion/src/metadata/ingestion/source/database/multi_db_source.py b/ingestion/src/metadata/ingestion/source/database/multi_db_source.py
index ad36881f5a4d..5050139caa6a 100644
--- a/ingestion/src/metadata/ingestion/source/database/multi_db_source.py
+++ b/ingestion/src/metadata/ingestion/source/database/multi_db_source.py
@@ -33,7 +33,9 @@ def get_database_names_raw(self) -> Iterable[str]:
         """
 
     def _execute_database_query(self, query: str) -> Iterable[str]:
-        results = self.connection.execute(text(query))  # pylint: disable=no-member
+        results = self.connection.execute(  # pylint: disable=no-member
+            text(query)
+        ).fetchall()
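+        # fetchall() buffers every row client-side, so closing this
+        # connection or disposing the engine mid-iteration can no longer
+        # invalidate the cursor behind the loop below.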
for res in results: row = list(res) yield row[0] diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py index d3688fd92528..5fbe1745575b 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py @@ -402,7 +402,7 @@ def get_configured_database(self) -> Optional[str]: return self.service_connection.database def get_database_names_raw(self) -> Iterable[str]: - results = self.connection.execute(text(SNOWFLAKE_GET_DATABASES)) + results = self.connection.execute(text(SNOWFLAKE_GET_DATABASES)).fetchall() for res in results: row = list(res) yield row[1] diff --git a/ingestion/tests/unit/topology/database/test_common_db_source.py b/ingestion/tests/unit/topology/database/test_common_db_source.py index 774c35cdfd69..052553f38bf6 100644 --- a/ingestion/tests/unit/topology/database/test_common_db_source.py +++ b/ingestion/tests/unit/topology/database/test_common_db_source.py @@ -13,9 +13,15 @@ Tests for CommonDbSourceService._prepare_foreign_constraints """ +import gc +import weakref from unittest.mock import MagicMock, patch import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine +from sqlalchemy.inspection import inspect +from sqlalchemy.pool import QueuePool from metadata.generated.schema.entity.data.table import ( Column, @@ -24,8 +30,10 @@ Table, TableConstraint, ) +from metadata.ingestion.connections.session import create_and_bind_thread_safe_session from metadata.ingestion.source.database.common_db_source import CommonDbSourceService from metadata.ingestion.source.database.database_service import DatabaseServiceSource +from metadata.ingestion.source.database.multi_db_source import MultiDBSource @pytest.fixture @@ -334,3 +342,228 @@ def test_constraint_with_none_columns_skipped(self): result = DatabaseServiceSource.normalize_table_constraints(constraints, columns) assert result[0].columns is None assert result[1].columns == ["id"] + + +class _ReleaseOnlySurrogate(CommonDbSourceService): + """ + Minimal concrete subclass that bypasses CommonDbSourceService.__init__ + (which needs a full workflow config) so we can drive _release_engine / + close against a real SQLAlchemy engine in isolation. + """ + + def __init__(self, engine=None): # pylint: disable=super-init-not-called + self.engine = engine + self._connection_map = {} + self._inspector_map = {} + self.session = None + self.ssl_manager = None + + def create(self, *args, **kwargs): # satisfy abstract method contract + raise NotImplementedError + + +def _make_release_surrogate(engine=None): + """ + Build a minimal stand-in for CommonDbSourceService that has just the + attributes _release_engine touches, bypassing the heavy __init__ that + requires a full workflow config. 
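+
+    Hypothetical usage (mirroring the fixtures below):
+
+        engine = create_engine("sqlite:///:memory:", poolclass=QueuePool)
+        surrogate = _make_release_surrogate(engine)
+        surrogate._release_engine()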
+ """ + return _ReleaseOnlySurrogate(engine=engine) + + +@pytest.fixture +def sqlite_engine(): + """Real, in-memory SQLite engine with an explicit QueuePool.""" + engine = create_engine("sqlite:///:memory:", poolclass=QueuePool) + yield engine + try: + engine.dispose() + except Exception: + pass + + +@pytest.fixture +def surrogate(sqlite_engine): + """Minimal CommonDbSourceService with a real engine attached.""" + return _make_release_surrogate(sqlite_engine) + + +class TestReleaseEngine: + """Option B: _release_engine closes all pooled connections, clears + inspector/session state, and disposes the engine regardless of + which thread called it.""" + + def test_closes_every_connection_map_entry(self, surrogate): + conn_a = surrogate.engine.connect() + conn_b = surrogate.engine.connect() + surrogate._connection_map[111] = conn_a + surrogate._connection_map[222] = conn_b + + surrogate._release_engine() + + assert conn_a.closed is True + assert conn_b.closed is True + assert surrogate._connection_map == {} + + def test_clears_inspector_map(self, surrogate): + surrogate._connection_map[999] = surrogate.engine.connect() + surrogate._inspector_map[999] = inspect(surrogate._connection_map[999]) + assert len(surrogate._inspector_map) == 1 + + surrogate._release_engine() + + assert surrogate._inspector_map == {} + + def test_disposes_pool_and_clears_engine_ref(self, surrogate): + captured_engine = surrogate.engine + original_pool = captured_engine.pool + assert isinstance(original_pool, QueuePool) + connection = surrogate.engine.connect() + surrogate._connection_map[1] = connection + + surrogate._release_engine() + + assert surrogate.engine is None + assert connection.closed is True + assert original_pool.checkedout() == 0 + + def test_removes_session(self, surrogate): + surrogate.session = create_and_bind_thread_safe_session(surrogate.engine) + assert surrogate.session is not None + + surrogate._release_engine() + + assert surrogate.session is None + + def test_idempotent_when_engine_is_none(self): + surrogate = _make_release_surrogate(engine=None) + surrogate._release_engine() + assert surrogate.engine is None + assert surrogate._connection_map == {} + assert surrogate._inspector_map == {} + + def test_tolerates_already_closed_connection(self, surrogate): + healthy = surrogate.engine.connect() + already_closed = surrogate.engine.connect() + already_closed.close() + surrogate._connection_map[1] = healthy + surrogate._connection_map[2] = already_closed + + surrogate._release_engine() + + assert healthy.closed is True + assert surrogate._connection_map == {} + + def test_closes_connections_from_arbitrary_thread_ids(self, surrogate): + """Key property of Option B: close-all, not detach-current-thread. + Every fairy in _connection_map must close regardless of the caller's + thread id.""" + conns = { + 111: surrogate.engine.connect(), + 222: surrogate.engine.connect(), + 333: surrogate.engine.connect(), + } + surrogate._connection_map.update(conns) + + surrogate._release_engine() + + for conn in conns.values(): + assert conn.closed is True + assert surrogate._connection_map == {} + + +class TestEngineGcReclamation: + """Acceptance test for the memory leak fix: after _release_engine and + dropping the strong reference, the old Engine must be garbage-collectable. 
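+    Taking a weakref before the release lets the test observe collection
+    without itself keeping the engine alive.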
+    The previous kill_active_connections path left _ConnectionRecord fairies
+    pinning the engine, which is what this test guards against."""
+
+    def test_old_engine_becomes_gc_eligible_after_release(self):
+        engine = create_engine("sqlite:///:memory:", poolclass=QueuePool)
+        surrogate = _make_release_surrogate(engine)
+        surrogate._connection_map[12345] = surrogate.engine.connect()
+
+        old_engine_ref = weakref.ref(surrogate.engine)
+
+        surrogate._release_engine()
+        surrogate.engine = None
+        engine = None  # drop local strong ref too
+
+        gc.collect()
+
+        assert old_engine_ref() is None
+
+
+class _FakeSource(MultiDBSource):
+    """Minimal MultiDBSource that exposes a real SQLAlchemy connection so we
+    can exercise _execute_database_query against a live cursor."""
+
+    def __init__(self, engine: Engine):
+        self._engine = engine
+        self._conn = engine.connect()
+
+    @property
+    def connection(self):
+        return self._conn
+
+    def close(self):
+        try:
+            self._conn.close()
+        except Exception:
+            pass
+
+    def get_configured_database(self):
+        return None
+
+    def get_database_names_raw(self):
+        return self._execute_database_query("SELECT name FROM dbs ORDER BY id")
+
+
+class TestExecuteDatabaseQueryEagerFetch:
+    """Option B Part 2: _execute_database_query must eagerly call .fetchall()
+    so that _release_engine closing the connection in _connection_map
+    (the original regression pattern from set_inspector) does not
+    invalidate the cursor the generator is iterating."""
+
+    @pytest.fixture
+    def seeded_engine(self):
+        engine = create_engine("sqlite:///:memory:", poolclass=QueuePool)
+        # engine.begin() commits on exit under both SQLAlchemy 1.4 and 2.0;
+        # Connection.commit() only exists on 2.0-style connections.
+        with engine.begin() as conn:
+            conn.execute(text("CREATE TABLE dbs (id INTEGER PRIMARY KEY, name TEXT)"))
+            conn.execute(
+                text(
+                    "INSERT INTO dbs(id, name) VALUES (1, 'alpha'), (2, 'beta'), (3, 'gamma')"
+                )
+            )
+        yield engine
+        try:
+            engine.dispose()
+        except Exception:
+            pass
+
+    @pytest.fixture
+    def fake_source(self, seeded_engine):
+        source = _FakeSource(seeded_engine)
+        yield source
+        source.close()
+
+    def test_generator_survives_connection_close_mid_iteration(self, fake_source):
+        # Simulates what _release_engine actually does: it close()s every
+        # connection in _connection_map BEFORE disposing the engine. Without
+        # .fetchall() the cursor would die at that close() and the next
+        # yield would raise; with .fetchall() the rows are already buffered.
+        generator = fake_source.get_database_names_raw()
+
+        first = next(generator)
+        assert first == "alpha"
+
+        fake_source._conn.close()
+
+        remaining = list(generator)
+        assert remaining == ["beta", "gamma"]
+
+    def test_returns_all_rows_in_order(self, fake_source):
+        results = list(fake_source.get_database_names_raw())
+
+        assert results == ["alpha", "beta", "gamma"]
diff --git a/ingestion/tests/unit/topology/database/test_snowflake.py b/ingestion/tests/unit/topology/database/test_snowflake.py
index 0e0d7674c917..1728d1b466b0 100644
--- a/ingestion/tests/unit/topology/database/test_snowflake.py
+++ b/ingestion/tests/unit/topology/database/test_snowflake.py
@@ -853,3 +853,49 @@ def test_empty_tag_value_skipped_with_warning(self):
             source.schema_tags_map["TEST_SCHEMA"][0],
             {"tag_name": "TEST_TAG", "tag_value": "123"},
         )
+
+
+class TestSnowflakeGetDatabaseNamesRawEagerFetch:
+    """
+    Option B Part 2 applied to Snowflake: get_database_names_raw must call
+    .fetchall() so that a subsequent engine.dispose() / set_inspector does
+    not invalidate the cursor mid-iteration.
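+    A MagicMock stands in for the live Snowflake connection here; only the
+    eager-fetch contract is asserted.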
+ """ + + @staticmethod + def _build_mock_rows(): + return [ + ["row_meta", "DB_A"], + ["row_meta", "DB_B"], + ["row_meta", "DB_C"], + ] + + def test_fetchall_invoked_exactly_once(self): + source = SnowflakeSource.__new__(SnowflakeSource) + result = MagicMock() + result.fetchall.return_value = self._build_mock_rows() + mock_conn = MagicMock() + mock_conn.execute.return_value = result + + with patch.object( + SnowflakeSource, "connection", new_callable=PropertyMock + ) as mocked_conn_prop: + mocked_conn_prop.return_value = mock_conn + list(source.get_database_names_raw()) + + assert result.fetchall.call_count == 1 + + def test_yields_database_names_in_order(self): + source = SnowflakeSource.__new__(SnowflakeSource) + result = MagicMock() + result.fetchall.return_value = self._build_mock_rows() + mock_conn = MagicMock() + mock_conn.execute.return_value = result + + with patch.object( + SnowflakeSource, "connection", new_callable=PropertyMock + ) as mocked_conn_prop: + mocked_conn_prop.return_value = mock_conn + names = list(source.get_database_names_raw()) + + assert names == ["DB_A", "DB_B", "DB_C"]