diff --git a/src/diffa/db/connect.py b/src/diffa/db/connect.py index 882f1fd..c0a8843 100644 --- a/src/diffa/db/connect.py +++ b/src/diffa/db/connect.py @@ -38,7 +38,6 @@ def connect(self): password=self.db_config["password"], sslmode="prefer", # Prefer SSL mode ) - self.conn.set_session(autocommit=True) return self.conn def close(self): diff --git a/src/diffa/db/data_models.py b/src/diffa/db/data_models.py index 559954d..2f950ff 100644 --- a/src/diffa/db/data_models.py +++ b/src/diffa/db/data_models.py @@ -178,14 +178,14 @@ def get_dimension_fields(cls) -> List[Tuple[str, type]]: return [(f.name, f.type) for f in fields(cls) if f.name not in base_fields] def get_dimension_values(self): - # check_date is still considered as a dimension field. In fact, it's a main dimension field. + # check_date is the primary dimension field, so it comes first for correct sort order. return { f[0]: getattr(self, f[0]) - for f in self.get_dimension_fields() + [("check_date", date)] + for f in [("check_date", date)] + self.get_dimension_fields() } - def to_flatten_dimension_format(self) -> dict: - return {tuple(self.get_dimension_values().items()): self} + def get_dimension_values_as_tuple(self): + return tuple(self.get_dimension_values().items()) class MergedCountCheck: @@ -204,6 +204,7 @@ def __init__( self.check_date = check_date for key, value in kwargs.items(): setattr(self, key, value) + self._dimensions = sorted(kwargs.keys()) self.is_valid = ( is_valid if is_valid is not None else source_count <= target_count @@ -235,18 +236,6 @@ def __lt__(self, other): def __str__(self): return f"MergedCountCheck({", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())})" - @classmethod - def create_with_dimensions(cls, dimension_fields: List[Tuple[str, type]]): - """Factory method to dynamically create a MergedCountCheck with a CountCheck schema""" - - return type( - cls.__name__, - (cls,), - reduce( - lambda x, y: x | y, map(lambda x: {x[0]: x[1]}, dimension_fields), {} - ), - ) - @classmethod def from_counts( cls, source: Optional[CountCheck] = None, target: Optional[CountCheck] = None @@ -258,6 +247,10 @@ def from_counts( return cls(**merged_count_check_values) + @property + def dimensions(self): + return self._dimensions + def to_diffa_check_schema( self, source_database: str, diff --git a/src/diffa/db/source_target.py b/src/diffa/db/source_target.py index 9c418fe..854c231 100644 --- a/src/diffa/db/source_target.py +++ b/src/diffa/db/source_target.py @@ -22,17 +22,32 @@ def __init__(self, db_config: SourceConfig) -> None: self.conn = PostgresConnection(self.db_config.get_db_config()) def _execute_query(self, query: str, sql_params: tuple = None): + """Execute query eagerly, return a lazy iterator over rows. + Uses a server-side cursor so rows are streamed in batches + rather than fetched entirely into client memory. + """ conn = self.conn.connect() - try: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor: - cursor.execute(query, sql_params) + cursor = conn.cursor( + name="diffa_cursor", + cursor_factory=psycopg2.extras.DictCursor, + ) + cursor.itersize = 2000 + cursor.execute(query, sql_params) + + def _iter(): + try: for row in cursor: yield row - except Exception as e: - logger.info("Error encountered. Closing the DB connection...") - conn.close() - raise e + except Exception as e: + logger.info("Error encountered. Closing the DB connection...") + conn.close() + raise e + finally: + cursor.close() + conn.rollback() + + return _iter() def _build_count_query( self, @@ -59,6 +74,11 @@ def _build_count_query( if diff_dimension_cols else "" ) + order_by_diff_dimensions_clause = ( + f", {','.join([f'{col} ASC' for col in diff_dimension_cols])}" + if diff_dimension_cols + else "" + ) return f""" SELECT @@ -72,10 +92,10 @@ def _build_count_query( GROUP BY created_at::DATE {group_by_diff_dimensions_clause} ORDER BY created_at::DATE ASC + {order_by_diff_dimensions_clause} """ def count(self, latest_check_date: date, invalid_check_dates: List[date]): - if self.db_config.get_diff_dimension_cols(): count_query = self._build_count_query( latest_check_date, @@ -103,7 +123,7 @@ def __init__(self, config_manager: ConfigManager): def get_counts( self, last_check_date: date, invalid_check_dates: Iterable[date] - ) -> Iterable[CountCheck]: + ) -> tuple[Iterable[CountCheck], Iterable[CountCheck]]: def to_count_check( count_dict: dict, diff_dimension_cols: Optional[List[str]] = None ) -> CountCheck: @@ -114,6 +134,8 @@ def to_count_check( else: return CountCheck(**count_dict) + # Execute both queries concurrently (DB processing time runs in parallel) + # then stream rows lazily via server-side cursors with ThreadPoolExecutor(max_workers=2) as executor: future_source_count = executor.submit( self.source_db.count, last_check_date, invalid_check_dates @@ -126,6 +148,7 @@ def to_count_check( future_source_count.result(), future_target_count.result(), ) + return map( partial( to_count_check, diff --git a/src/diffa/managers/check_manager.py b/src/diffa/managers/check_manager.py index 59bed91..c8c3c99 100644 --- a/src/diffa/managers/check_manager.py +++ b/src/diffa/managers/check_manager.py @@ -65,71 +65,12 @@ def compare_tables(self): ) ) - # Step 5: Build and log the check summary - self._build_check_summary(merged_count_checks, merged_by_date) - # Return True if there is any invalid diff return self._check_if_valid_diff(merged_by_date.values()) def _check_if_valid_diff(self, merged_by_date: list[MergedCountCheck]) -> bool: return all(mcc.is_valid for mcc in merged_by_date) - def _build_check_summary( - self, - merged_count_checks: Iterable[MergedCountCheck], - merged_by_date: dict[date, MergedCountCheck], - ): - stats_by_day = { - check_date: { - "detailed_msgs": self._get_check_messages( - self._get_checks_by_date(merged_count_checks, check_date) - ), - "summary_msg": self._get_check_messages([mcc])[0], - } - for check_date, mcc in filter( - lambda x: not x[1].is_valid, merged_by_date.items() - ) - } - - summary_lines = [ - f""" - - {check_date}: - summary: - {stats['summary_msg']} - detailed: - {stats['detailed_msgs']} - """ - for check_date, stats in stats_by_day.items() - ] - stats_summary = ( - "\n".join(summary_lines) - if summary_lines - else "No failed days stats available" - ) - - logger.info( - f""" - Data-diff comparison result: - Summary: - - Total days checked: {len(merged_by_date)} - - Stats by failed days: - {stats_summary} - """ - ) - - @staticmethod - def _get_check_messages(merged_count_checks: Iterable[MergedCountCheck]): - return [ - f"{'✅ No Diff' if mcc.is_valid else '❌ Diff'} {mcc}" - for mcc in merged_count_checks - ] - - @staticmethod - def _get_checks_by_date( - merged_count_checks: Iterable[MergedCountCheck], check_date: date - ) -> list[MergedCountCheck]: - return [mcc for mcc in merged_count_checks if mcc.check_date == check_date] - @staticmethod def _merge_by_check_date( merged_count_checks: Iterable[MergedCountCheck], @@ -144,7 +85,17 @@ def _merge_by_check_date( entry["is_valid"] &= mcc.is_valid entry["check_date"] = mcc.check_date - return {cd: MergedCountCheck(**data) for cd, data in merged.items()} + if not mcc.is_valid: + logger.info( + f"❌ Diff found for dimensions {mcc.dimensions}: " + f"Source Count: {mcc.source_count}, " + f"Target Count: {mcc.target_count}, " + f"Is Valid: {mcc.is_valid} " + ) + + return { + check_date: MergedCountCheck(**data) for check_date, data in merged.items() + } def _merge_count_checks( self, source_counts: Iterable[CountCheck], target_counts: Iterable[CountCheck] @@ -156,24 +107,32 @@ def _merge_count_checks( Iterable B: [2,4,5,7] Output [(1,0), (2,2), (0,4), (5,5), (6,0), (0,7)] """ - - source_dict = {} - for count_check in source_counts: - source_dict.update(count_check.to_flatten_dimension_format()) - - target_dict = {} - for count_check in target_counts: - target_dict.update(count_check.to_flatten_dimension_format()) - - all_dims = set(source_dict.keys()) | set(target_dict.keys()) - - merged_count_checks = [] - for dim in all_dims: - source_count = source_dict.get(dim) - target_count = target_dict.get(dim) - merged_count_check = MergedCountCheck.from_counts( - source_count, target_count - ) - merged_count_checks.append(merged_count_check) - - return sorted(merged_count_checks, key=lambda x: x.check_date) + # 2 pointers algorithm to merge the two iterables based on the dimension values (if any) with O(1) space complexity + source_iter = iter(source_counts) + target_iter = iter(target_counts) + source_current = next(source_iter, None) + target_current = next(target_iter, None) + while source_current or target_current: + if source_current and target_current: + if ( + source_current.get_dimension_values_as_tuple() + == target_current.get_dimension_values_as_tuple() + ): + yield MergedCountCheck.from_counts(source_current, target_current) + source_current = next(source_iter, None) + target_current = next(target_iter, None) + elif ( + source_current.get_dimension_values_as_tuple() + < target_current.get_dimension_values_as_tuple() + ): + yield MergedCountCheck.from_counts(source_current, None) + source_current = next(source_iter, None) + else: + yield MergedCountCheck.from_counts(None, target_current) + target_current = next(target_iter, None) + elif source_current: + yield MergedCountCheck.from_counts(source_current, None) + source_current = next(source_iter, None) + elif target_current: + yield MergedCountCheck.from_counts(None, target_current) + target_current = next(target_iter, None) diff --git a/tests/managers/test_check_manager.py b/tests/managers/test_check_manager.py index ea42f36..20170b2 100644 --- a/tests/managers/test_check_manager.py +++ b/tests/managers/test_check_manager.py @@ -105,7 +105,7 @@ def check_manager(): def test__merge_count_check( check_manager, source_counts, target_counts, expected_merged_counts ): - merged_counts = check_manager._merge_count_checks(source_counts, target_counts) + merged_counts = list(check_manager._merge_count_checks(source_counts, target_counts)) assert expected_merged_counts == merged_counts @pytest.mark.parametrize( @@ -130,7 +130,7 @@ def test__merge_count_check( ) ], [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=100, target_count=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), @@ -142,8 +142,7 @@ def test__merge_count_check( # Case 2: Checking dates are in source only ( [ - CountCheck.create_with_dimensions( - ["status", "country"])( + CountCheck.create_with_dimensions(["status", "country"])( cnt=100, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", @@ -152,7 +151,7 @@ def test__merge_count_check( ], [], [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=100, target_count=0, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), @@ -173,7 +172,7 @@ def test__merge_count_check( ) ], [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=0, target_count=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), @@ -183,20 +182,23 @@ def test__merge_count_check( ], ), # Case 4: Checking different dates in source and target + # Two-pointer merge orders by dimension tuple: (check_date, country, status) + # matching the SQL ORDER BY check_date ASC, country ASC, status ASC ( + # Source inputs sorted as DB would return: check_date ASC, country ASC, status ASC [ CountCheck.create_with_dimensions(["status", "country"])( cnt=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", - country="US" + country="Singapore" ), CountCheck.create_with_dimensions(["status", "country"])( cnt=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", - country="Singapore" - ) + country="US" + ), ], [ CountCheck.create_with_dimensions(["status", "country"])( @@ -207,21 +209,21 @@ def test__merge_count_check( ) ], [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=200, target_count=0, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", - country="US" + country="Singapore" ), - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=200, target_count=0, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", - country="Singapore" + country="US" ), - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=0, target_count=200, check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), @@ -236,7 +238,7 @@ def test__merge_count_check( ], ) def test__merge_count_check_with_dimensions(check_manager, source_counts, target_counts, expected_merged_counts): - merged_counts = check_manager._merge_count_checks(source_counts, target_counts) + merged_counts = list(check_manager._merge_count_checks(source_counts, target_counts)) assert expected_merged_counts == merged_counts @@ -275,19 +277,19 @@ def test__merge_count_check_with_dimensions(check_manager, source_counts, target # Case 2: Merge count checks by check date with 1 dimension field (happy case) ( [ - MergedCountCheck.create_with_dimensions(["status"])( + MergedCountCheck( source_count=100, target_count=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True" ), - MergedCountCheck.create_with_dimensions(["status"])( + MergedCountCheck( source_count=200, target_count=300, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="False" ), - MergedCountCheck.create_with_dimensions(["status"])( + MergedCountCheck( source_count=400, target_count=300, check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(), @@ -312,21 +314,21 @@ def test__merge_count_check_with_dimensions(check_manager, source_counts, target # Case 3: Merge count checks by check date with 2 dimension fields (unhappy case: dimenssion failure => invalid diff) ( [ - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=100, target_count=200, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="True", country="US" ), - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=200, target_count=100, check_date=datetime.strptime("2024-01-01", "%Y-%m-%d").date(), status="False", country="US" ), - MergedCountCheck.create_with_dimensions(["status", "country"])( + MergedCountCheck( source_count=300, target_count=400, check_date=datetime.strptime("2024-01-02", "%Y-%m-%d").date(),