From 3b894c92ef215738e9e64703aa790fb778abe516 Mon Sep 17 00:00:00 2001 From: Quoc Nguyen Date: Mon, 6 Apr 2026 00:58:41 +0700 Subject: [PATCH 1/5] fix: resolve OOM in high-cardinality dimension merging - Replace the O(n)-memory dict-based merge with an O(1)-memory two-pointer generator (both inputs must arrive sorted by check_date, then by the dimension columns) - Use server-side cursors with generator factory pattern for true DB streaming - Fix check_date sort priority in dimension tuple ordering - Fix MergedCountCheck dimensions property and equality check - Inline diff logging in _merge_by_check_date, remove _build_check_summary Co-Authored-By: Claude Opus 4.6 (1M context) --- src/diffa/db/data_models.py | 30 +++---- src/diffa/db/source_target.py | 50 +++++++---- src/diffa/managers/check_manager.py | 121 +++++++++------------------ tests/managers/test_check_manager.py | 44 +++++----- 4 files changed, 111 insertions(+), 134 deletions(-) diff --git a/src/diffa/db/data_models.py b/src/diffa/db/data_models.py index 559954d..0e89727 100644 --- a/src/diffa/db/data_models.py +++ b/src/diffa/db/data_models.py @@ -178,14 +178,14 @@ def get_dimension_fields(cls) -> List[Tuple[str, type]]: return [(f.name, f.type) for f in fields(cls) if f.name not in base_fields] def get_dimension_values(self): - # check_date is still considered as a dimension field. In fact, it's a main dimension field. + # check_date is the primary dimension field, so it comes first for correct sort order. 
return { f[0]: getattr(self, f[0]) - for f in self.get_dimension_fields() + [("check_date", date)] + for f in [("check_date", date)] + self.get_dimension_fields() } - def to_flatten_dimension_format(self) -> dict: - return {tuple(self.get_dimension_values().items()): self} + def get_dimension_values_as_tuple(self): + return tuple(self.get_dimension_values().items()) class MergedCountCheck: @@ -204,6 +204,7 @@ def __init__( self.check_date = check_date for key, value in kwargs.items(): setattr(self, key, value) + self._dimensions = list(kwargs.keys()) self.is_valid = ( is_valid if is_valid is not None else source_count <= target_count @@ -212,7 +213,10 @@ def __init__( def __eq__(self, other): if not isinstance(other, MergedCountCheck): return NotImplemented - return self.__dict__ == other.__dict__ + excluded = {"_dimensions"} + return {k: v for k, v in self.__dict__.items() if k not in excluded} == { + k: v for k, v in other.__dict__.items() if k not in excluded + } def __lt__(self, other): if not isinstance(other, MergedCountCheck): @@ -235,18 +239,6 @@ def __lt__(self, other): def __str__(self): return f"MergedCountCheck({", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())})" - @classmethod - def create_with_dimensions(cls, dimension_fields: List[Tuple[str, type]]): - """Factory method to dynamically create a MergedCountCheck with a CountCheck schema""" - - return type( - cls.__name__, - (cls,), - reduce( - lambda x, y: x | y, map(lambda x: {x[0]: x[1]}, dimension_fields), {} - ), - ) - @classmethod def from_counts( cls, source: Optional[CountCheck] = None, target: Optional[CountCheck] = None @@ -258,6 +250,10 @@ def from_counts( return cls(**merged_count_check_values) + @property + def dimensions(self): + return self._dimensions + def to_diffa_check_schema( self, source_database: str, diff --git a/src/diffa/db/source_target.py b/src/diffa/db/source_target.py index 9c418fe..e6f8acc 100644 --- a/src/diffa/db/source_target.py +++ 
b/src/diffa/db/source_target.py @@ -22,17 +22,31 @@ def __init__(self, db_config: SourceConfig) -> None: self.conn = PostgresConnection(self.db_config.get_db_config()) def _execute_query(self, query: str, sql_params: tuple = None): + """Execute query eagerly, return a lazy iterator over rows. + Uses a server-side cursor so rows are streamed in batches + rather than fetched entirely into client memory. + """ conn = self.conn.connect() - try: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor: - cursor.execute(query, sql_params) + cursor = conn.cursor( + name="diffa_cursor", + cursor_factory=psycopg2.extras.DictCursor, + ) + cursor.itersize = 2000 + cursor.execute(query, sql_params) + + def _iter(): + try: for row in cursor: yield row - except Exception as e: - logger.info("Error encountered. Closing the DB connection...") - conn.close() - raise e + except Exception as e: + logger.info("Error encountered. Closing the DB connection...") + conn.close() + raise e + finally: + cursor.close() + + return _iter() def _build_count_query( self, @@ -59,6 +73,11 @@ def _build_count_query( if diff_dimension_cols else "" ) + order_by_diff_dimensions_clause = ( + f", {','.join([f'{col} ASC' for col in diff_dimension_cols])}" + if diff_dimension_cols + else "" + ) return f""" SELECT @@ -72,10 +91,10 @@ def _build_count_query( GROUP BY created_at::DATE {group_by_diff_dimensions_clause} ORDER BY created_at::DATE ASC + {order_by_diff_dimensions_clause} """ def count(self, latest_check_date: date, invalid_check_dates: List[date]): - if self.db_config.get_diff_dimension_cols(): count_query = self._build_count_query( latest_check_date, @@ -103,7 +122,7 @@ def __init__(self, config_manager: ConfigManager): def get_counts( self, last_check_date: date, invalid_check_dates: Iterable[date] - ) -> Iterable[CountCheck]: + ) -> tuple[Iterable[CountCheck], Iterable[CountCheck]]: def to_count_check( count_dict: dict, diff_dimension_cols: Optional[List[str]] = None ) -> 
CountCheck: @@ -114,18 +133,19 @@ def to_count_check( else: return CountCheck(**count_dict) + # Execute both queries concurrently (DB processing time runs in parallel) + # then stream rows lazily via server-side cursors with ThreadPoolExecutor(max_workers=2) as executor: - future_source_count = executor.submit( + future_source = executor.submit( self.source_db.count, last_check_date, invalid_check_dates ) - future_target_count = executor.submit( + future_target = executor.submit( self.target_db.count, last_check_date, invalid_check_dates ) - source_counts, target_counts = ( - future_source_count.result(), - future_target_count.result(), - ) + source_counts = future_source.result() + target_counts = future_target.result() + return map( partial( to_count_check, diff --git a/src/diffa/managers/check_manager.py b/src/diffa/managers/check_manager.py index 59bed91..c8c3c99 100644 --- a/src/diffa/managers/check_manager.py +++ b/src/diffa/managers/check_manager.py @@ -65,71 +65,12 @@ def compare_tables(self): ) ) - # Step 5: Build and log the check summary - self._build_check_summary(merged_count_checks, merged_by_date) - # Return True if there is any invalid diff return self._check_if_valid_diff(merged_by_date.values()) def _check_if_valid_diff(self, merged_by_date: list[MergedCountCheck]) -> bool: return all(mcc.is_valid for mcc in merged_by_date) - def _build_check_summary( - self, - merged_count_checks: Iterable[MergedCountCheck], - merged_by_date: dict[date, MergedCountCheck], - ): - stats_by_day = { - check_date: { - "detailed_msgs": self._get_check_messages( - self._get_checks_by_date(merged_count_checks, check_date) - ), - "summary_msg": self._get_check_messages([mcc])[0], - } - for check_date, mcc in filter( - lambda x: not x[1].is_valid, merged_by_date.items() - ) - } - - summary_lines = [ - f""" - - {check_date}: - summary: - {stats['summary_msg']} - detailed: - {stats['detailed_msgs']} - """ - for check_date, stats in stats_by_day.items() - ] - stats_summary = 
( - "\n".join(summary_lines) - if summary_lines - else "No failed days stats available" - ) - - logger.info( - f""" - Data-diff comparison result: - Summary: - - Total days checked: {len(merged_by_date)} - - Stats by failed days: - {stats_summary} - """ - ) - - @staticmethod - def _get_check_messages(merged_count_checks: Iterable[MergedCountCheck]): - return [ - f"{'✅ No Diff' if mcc.is_valid else '❌ Diff'} {mcc}" - for mcc in merged_count_checks - ] - - @staticmethod - def _get_checks_by_date( - merged_count_checks: Iterable[MergedCountCheck], check_date: date - ) -> list[MergedCountCheck]: - return [mcc for mcc in merged_count_checks if mcc.check_date == check_date] - @staticmethod def _merge_by_check_date( merged_count_checks: Iterable[MergedCountCheck], @@ -144,7 +85,17 @@ def _merge_by_check_date( entry["is_valid"] &= mcc.is_valid entry["check_date"] = mcc.check_date - return {cd: MergedCountCheck(**data) for cd, data in merged.items()} + if not mcc.is_valid: + logger.info( + f"❌ Diff found for dimensions {mcc.dimensions}: " + f"Source Count: {mcc.source_count}, " + f"Target Count: {mcc.target_count}, " + f"Is Valid: {mcc.is_valid} " + ) + + return { + check_date: MergedCountCheck(**data) for check_date, data in merged.items() + } def _merge_count_checks( self, source_counts: Iterable[CountCheck], target_counts: Iterable[CountCheck] @@ -156,24 +107,32 @@ def _merge_count_checks( Iterable B: [2,4,5,7] Output [(1,0), (2,2), (0,4), (5,5), (6,0), (0,7)] """ - - source_dict = {} - for count_check in source_counts: - source_dict.update(count_check.to_flatten_dimension_format()) - - target_dict = {} - for count_check in target_counts: - target_dict.update(count_check.to_flatten_dimension_format()) - - all_dims = set(source_dict.keys()) | set(target_dict.keys()) - - merged_count_checks = [] - for dim in all_dims: - source_count = source_dict.get(dim) - target_count = target_dict.get(dim) - merged_count_check = MergedCountCheck.from_counts( - source_count, target_count 
- ) - merged_count_checks.append(merged_count_check) - - return sorted(merged_count_checks, key=lambda x: x.check_date) + # 2 pointers algorithm to merge the two iterables based on the dimension values (if any) with O(1) space complexity + source_iter = iter(source_counts) + target_iter = iter(target_counts) + source_current = next(source_iter, None) + target_current = next(target_iter, None) + while source_current or target_current: + if source_current and target_current: + if ( + source_current.get_dimension_values_as_tuple() + == target_current.get_dimension_values_as_tuple() + ): + yield MergedCountCheck.from_counts(source_current, target_current) + source_current = next(source_iter, None) + target_current = next(target_iter, None) + elif ( + source_current.get_dimension_values_as_tuple() + < target_current.get_dimension_values_as_tuple() + ): + yield MergedCountCheck.from_counts(source_current, None) + source_current = next(source_iter, None) + else: + yield MergedCountCheck.from_counts(None, target_current) + target_current = next(target_iter, None) + elif source_current: + yield MergedCountCheck.from_counts(source_current, None) + source_current = next(source_iter, None) + elif target_current: + yield MergedCountCheck.from_counts(None, target_current) + target_current = next(target_iter, None) diff --git a/tests/managers/test_check_manager.py b/tests/managers/test_check_manager.py index ea42f36..20170b2 100644 --- a/tests/managers/test_check_manager.py +++ b/tests/managers/test_check_manager.py @@ -105,7 +105,7 @@ def check_manager(): def test__merge_count_check( check_manager, source_counts, target_counts, expected_merged_counts ): - merged_counts = check_manager._merge_count_checks(source_counts, target_counts) + merged_counts = list(check_manager._merge_count_checks(source_counts, target_counts)) assert expected_merged_counts == merged_counts @pytest.mark.parametrize( @@ -130,7 +130,7 @@ def test__merge_count_check( ) ], [ - 
MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=100, target_count=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), @@ -142,8 +142,7 @@ def test__merge_count_check( # Case 2: Checking dates are in source only ( [ - CountCheck.create_with_dimensions( - ["status", "country"])( + CountCheck.create_with_dimensions(["status", "country"])( cnt=100, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", @@ -152,7 +151,7 @@ def test__merge_count_check( ], [], [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=100, target_count=0, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), @@ -173,7 +172,7 @@ def test__merge_count_check( ) ], [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=0, target_count=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), @@ -183,20 +182,23 @@ def test__merge_count_check( ], ), # Case 4: Checking different dates in source and target + # Two-pointer merge orders by dimension tuple: (check_date, country, status) + # matching the SQL ORDER BY check_date ASC, country ASC, status ASC ( + # Source inputs sorted as DB would return: check_date ASC, country ASC, status ASC [ CountCheck.create_with_dimensions(["status", "country"])( cnt=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", - country="US" + country="Singapore" ), CountCheck.create_with_dimensions(["status", "country"])( cnt=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", - country="Singapore" - ) + country="US" + ), ], [ CountCheck.create_with_dimensions(["status", "country"])( @@ -207,21 +209,21 @@ def test__merge_count_check( ) ], [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=200, target_count=0, check_date=datetime.strptime("2024-01-01", 
"%Y-%m-%d").date(), status="True", - country="US" + country="Singapore" ), - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=200, target_count=0, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", - country="Singapore" + country="US" ), - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=0, target_count=200, check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), @@ -236,7 +238,7 @@ def test__merge_count_check( ], ) def test__merge_count_check_with_dimensions(check_manager, source_counts, target_counts, expected_merged_counts): - merged_counts = check_manager._merge_count_checks(source_counts, target_counts) + merged_counts = list(check_manager._merge_count_checks(source_counts, target_counts)) assert expected_merged_counts == merged_counts @@ -275,19 +277,19 @@ def test__merge_count_check_with_dimensions(check_manager, source_counts, target # Case 2: Merge count checks by check date with 1 dimension field (happy case) ( [ - MergedCountCheck.create_with_dimensions(["status"])( + MergedCountCheck( source_count=100, target_count=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True" ), - MergedCountCheck.create_with_dimensions(["status"])( + MergedCountCheck( source_count=200, target_count=300, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="False" ), - MergedCountCheck.create_with_dimensions(["status"])( + MergedCountCheck( source_count=400, target_count=300, check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), @@ -312,21 +314,21 @@ def test__merge_count_check_with_dimensions(check_manager, source_counts, target # Case 3: Merge count checks by check date with 2 dimension fields (unhappy case: dimenssion failure => invalid diff) ( [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=100, target_count=200, 
check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", country="US" ), - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=200, target_count=100, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="False", country="US" ), - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=300, target_count=400, check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), From 807d8b7ee03ae9d717c415b0dff866a3bc5435b5 Mon Sep 17 00:00:00 2001 From: Quoc Nguyen Date: Mon, 6 Apr 2026 01:02:06 +0700 Subject: [PATCH 2/5] fix: sort _dimensions for consistent equality instead of excluding from __eq__ Co-Authored-By: Claude Opus 4.6 (1M context) --- src/diffa/db/data_models.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/diffa/db/data_models.py b/src/diffa/db/data_models.py index 0e89727..2f950ff 100644 --- a/src/diffa/db/data_models.py +++ b/src/diffa/db/data_models.py @@ -204,7 +204,7 @@ def __init__( self.check_date = check_date for key, value in kwargs.items(): setattr(self, key, value) - self._dimensions = list(kwargs.keys()) + self._dimensions = sorted(kwargs.keys()) self.is_valid = ( is_valid if is_valid is not None else source_count <= target_count @@ -213,10 +213,7 @@ def __init__( def __eq__(self, other): if not isinstance(other, MergedCountCheck): return NotImplemented - excluded = {"_dimensions"} - return {k: v for k, v in self.__dict__.items() if k not in excluded} == { - k: v for k, v in other.__dict__.items() if k not in excluded - } + return self.__dict__ == other.__dict__ def __lt__(self, other): if not isinstance(other, MergedCountCheck): From 28a40964f4a476e5a99e20f1e30ae21615ec1385 Mon Sep 17 00:00:00 2001 From: Quoc Nguyen Date: Mon, 6 Apr 2026 01:05:41 +0700 Subject: [PATCH 3/5] refactor: retain original variable names in get_counts Co-Authored-By: Claude Opus 4.6 (1M context) --- 
src/diffa/db/source_target.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffa/db/source_target.py b/src/diffa/db/source_target.py index e6f8acc..7ef30f1 100644 --- a/src/diffa/db/source_target.py +++ b/src/diffa/db/source_target.py @@ -133,18 +133,18 @@ def to_count_check( else: return CountCheck(**count_dict) - # Execute both queries concurrently (DB processing time runs in parallel) - # then stream rows lazily via server-side cursors with ThreadPoolExecutor(max_workers=2) as executor: - future_source = executor.submit( + future_source_count = executor.submit( self.source_db.count, last_check_date, invalid_check_dates ) - future_target = executor.submit( + future_target_count = executor.submit( self.target_db.count, last_check_date, invalid_check_dates ) - source_counts = future_source.result() - target_counts = future_target.result() + source_counts, target_counts = ( + future_source_count.result(), + future_target_count.result(), + ) return map( partial( From ab690d25a1d1efb955e4540641f4f7b14f3d2bd3 Mon Sep 17 00:00:00 2001 From: Quoc Nguyen Date: Mon, 6 Apr 2026 01:06:35 +0700 Subject: [PATCH 4/5] docs: re-add comment explaining concurrent query execution Co-Authored-By: Claude Opus 4.6 (1M context) --- src/diffa/db/source_target.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffa/db/source_target.py b/src/diffa/db/source_target.py index 7ef30f1..bba2c02 100644 --- a/src/diffa/db/source_target.py +++ b/src/diffa/db/source_target.py @@ -133,6 +133,8 @@ def to_count_check( else: return CountCheck(**count_dict) + # Execute both queries concurrently (DB processing time runs in parallel) + # then stream rows lazily via server-side cursors with ThreadPoolExecutor(max_workers=2) as executor: future_source_count = executor.submit( self.source_db.count, last_check_date, invalid_check_dates From 035dedd4b9ecca27b4d1657aadd53699001f332d Mon Sep 17 00:00:00 2001 From: Quoc Nguyen Date: Mon, 18 May 2026 10:20:05 +0700 
Subject: [PATCH 5/5] fix: stop enabling autocommit so the server-side cursor runs inside a transaction; roll back after streaming --- src/diffa/db/connect.py | 1 - src/diffa/db/source_target.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffa/db/connect.py b/src/diffa/db/connect.py index 882f1fd..c0a8843 100644 --- a/src/diffa/db/connect.py +++ b/src/diffa/db/connect.py @@ -38,7 +38,6 @@ def connect(self): password=self.db_config["password"], sslmode="prefer", # Prefer SSL mode ) - self.conn.set_session(autocommit=True) return self.conn def close(self): diff --git a/src/diffa/db/source_target.py b/src/diffa/db/source_target.py index bba2c02..854c231 100644 --- a/src/diffa/db/source_target.py +++ b/src/diffa/db/source_target.py @@ -45,6 +45,7 @@ def _iter(): raise e finally: cursor.close() + conn.rollback() return _iter()