Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/diffa/db/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def connect(self):
password=self.db_config["password"],
sslmode="prefer", # Prefer SSL mode
)
self.conn.set_session(autocommit=True)
return self.conn

def close(self):
Expand Down
25 changes: 9 additions & 16 deletions src/diffa/db/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@ def get_dimension_fields(cls) -> List[Tuple[str, type]]:
return [(f.name, f.type) for f in fields(cls) if f.name not in base_fields]

def get_dimension_values(self):
    """Map every dimension field name to this instance's value for it.

    ``check_date`` is the primary dimension, so it is emitted first; the
    remaining names follow in the order ``get_dimension_fields()`` reports
    them. Dicts preserve insertion order, so anything derived from
    ``.items()`` sees ``check_date`` first, which is what ordering by
    dimension relies on.

    Returns:
        dict: ``{field_name: value}`` with ``check_date`` as the first key.
    """
    # Only the name (first element of each (name, type) pair) is needed here.
    names = ["check_date"] + [f[0] for f in self.get_dimension_fields()]
    return {name: getattr(self, name) for name in names}

def to_flatten_dimension_format(self) -> dict:
    """Return a single-entry dict keyed by this check's dimension items.

    The key is the tuple of ``(field_name, value)`` pairs produced by
    ``get_dimension_values()``; the value is this object itself.
    """
    dimension_key = tuple(self.get_dimension_values().items())
    return {dimension_key: self}
def get_dimension_values_as_tuple(self):
    """Return the dimension values as a tuple of ``(field_name, value)`` pairs.

    The pairs appear in the same order as the dict returned by
    ``get_dimension_values()``.
    """
    dimension_values = self.get_dimension_values()
    return tuple(dimension_values.items())


class MergedCountCheck:
Expand All @@ -204,6 +204,7 @@ def __init__(
self.check_date = check_date
for key, value in kwargs.items():
setattr(self, key, value)
self._dimensions = sorted(kwargs.keys())

self.is_valid = (
is_valid if is_valid is not None else source_count <= target_count
Expand Down Expand Up @@ -235,18 +236,6 @@ def __lt__(self, other):
def __str__(self):
    """Render as ``MergedCountCheck(name=value, ...)`` over all instance attributes.

    Every entry of the instance ``__dict__`` is included, in insertion
    order, with each value shown via ``repr``.
    """
    # Build the attribute list outside the f-string: reusing the enclosing
    # quote character inside a replacement field only parses on
    # Python 3.12+ (PEP 701) and is a SyntaxError on earlier versions.
    attributes = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
    return f"MergedCountCheck({attributes})"

@classmethod
def create_with_dimensions(cls, dimension_fields: List[Tuple[str, type]]):
    """Create a subclass of ``cls`` carrying the given dimension fields.

    The new class keeps ``cls``'s name, inherits from ``cls``, and gets one
    class attribute per ``(name, type)`` pair, set to the pair's type. When
    a name occurs more than once, the last pair wins.
    """
    namespace = {}
    for field in dimension_fields:
        namespace[field[0]] = field[1]
    return type(cls.__name__, (cls,), namespace)

@classmethod
def from_counts(
cls, source: Optional[CountCheck] = None, target: Optional[CountCheck] = None
Expand All @@ -258,6 +247,10 @@ def from_counts(

return cls(**merged_count_check_values)

@property
def dimensions(self):
    """Read-only access to the dimension names stored in ``_dimensions``."""
    recorded_dimensions = self._dimensions
    return recorded_dimensions

def to_diffa_check_schema(
self,
source_database: str,
Expand Down
41 changes: 32 additions & 9 deletions src/diffa/db/source_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,32 @@ def __init__(self, db_config: SourceConfig) -> None:
self.conn = PostgresConnection(self.db_config.get_db_config())

def _execute_query(self, query: str, sql_params: tuple = None):
"""Execute query eagerly, return a lazy iterator over rows.

Uses a server-side cursor so rows are streamed in batches
rather than fetched entirely into client memory.
"""
conn = self.conn.connect()
try:
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
cursor.execute(query, sql_params)
cursor = conn.cursor(
name="diffa_cursor",
cursor_factory=psycopg2.extras.DictCursor,
)
cursor.itersize = 2000
cursor.execute(query, sql_params)

def _iter():
try:
for row in cursor:
yield row
except Exception as e:
logger.info("Error encountered. Closing the DB connection...")
conn.close()
raise e
except Exception as e:
logger.info("Error encountered. Closing the DB connection...")
conn.close()
raise e
finally:
cursor.close()
conn.rollback()

return _iter()

def _build_count_query(
self,
Expand All @@ -59,6 +74,11 @@ def _build_count_query(
if diff_dimension_cols
else ""
)
order_by_diff_dimensions_clause = (
f", {','.join([f'{col} ASC' for col in diff_dimension_cols])}"
if diff_dimension_cols
else ""
)

return f"""
SELECT
Expand All @@ -72,10 +92,10 @@ def _build_count_query(
GROUP BY created_at::DATE
{group_by_diff_dimensions_clause}
ORDER BY created_at::DATE ASC
{order_by_diff_dimensions_clause}
"""

def count(self, latest_check_date: date, invalid_check_dates: List[date]):

if self.db_config.get_diff_dimension_cols():
count_query = self._build_count_query(
latest_check_date,
Expand Down Expand Up @@ -103,7 +123,7 @@ def __init__(self, config_manager: ConfigManager):

def get_counts(
self, last_check_date: date, invalid_check_dates: Iterable[date]
) -> Iterable[CountCheck]:
) -> tuple[Iterable[CountCheck], Iterable[CountCheck]]:
def to_count_check(
count_dict: dict, diff_dimension_cols: Optional[List[str]] = None
) -> CountCheck:
Expand All @@ -114,6 +134,8 @@ def to_count_check(
else:
return CountCheck(**count_dict)

# Execute both queries concurrently (DB processing time runs in parallel)
# then stream rows lazily via server-side cursors
with ThreadPoolExecutor(max_workers=2) as executor:
future_source_count = executor.submit(
self.source_db.count, last_check_date, invalid_check_dates
Expand All @@ -126,6 +148,7 @@ def to_count_check(
future_source_count.result(),
future_target_count.result(),
)

return map(
partial(
to_count_check,
Expand Down
121 changes: 40 additions & 81 deletions src/diffa/managers/check_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,71 +65,12 @@ def compare_tables(self):
)
)

# Step 5: Build and log the check summary
self._build_check_summary(merged_count_checks, merged_by_date)

# Return True only when every checked date is valid (i.e. no diff was found)
return self._check_if_valid_diff(merged_by_date.values())

def _check_if_valid_diff(self, merged_by_date: Iterable[MergedCountCheck]) -> bool:
    """Report whether every merged check passed.

    Args:
        merged_by_date: The merged checks to inspect. Any iterable is
            accepted; it is traversed once and only ``is_valid`` is read.

    Returns:
        True when all checks are valid (also for an empty iterable),
        False as soon as one invalid check is encountered.
    """
    return all(mcc.is_valid for mcc in merged_by_date)

def _build_check_summary(
    self,
    merged_count_checks: Iterable[MergedCountCheck],
    merged_by_date: dict[date, MergedCountCheck],
):
    """Log a summary of the comparison, with details for each failed day.

    Args:
        merged_count_checks: The individual merged checks; used to collect
            the detailed messages for each failed date.
        merged_by_date: One merged check per check date; its size is
            reported as the number of days checked, and its invalid entries
            determine which dates are listed.

    Nothing is returned; the summary is emitted through ``logger.info``.
    """
    # Only dates whose per-date merged check is invalid get an entry.
    # NOTE(review): merged_count_checks is scanned once per failed date, so a
    # one-shot iterator/generator would be exhausted after the first date —
    # confirm callers pass a re-iterable collection.
    stats_by_day = {
        check_date: {
            "detailed_msgs": self._get_check_messages(
                self._get_checks_by_date(merged_count_checks, check_date)
            ),
            "summary_msg": self._get_check_messages([mcc])[0],
        }
        for check_date, mcc in filter(
            lambda x: not x[1].is_valid, merged_by_date.items()
        )
    }

    # One text fragment per failed date; 'detailed_msgs' is a list, so it is
    # interpolated using the list's own string form.
    summary_lines = [
        f"""
- {check_date}:
summary:
{stats['summary_msg']}
detailed:
{stats['detailed_msgs']}
"""
        for check_date, stats in stats_by_day.items()
    ]
    # Fall back to a fixed message when no date failed.
    stats_summary = (
        "\n".join(summary_lines)
        if summary_lines
        else "No failed days stats available"
    )

    logger.info(
        f"""
Data-diff comparison result:
Summary:
- Total days checked: {len(merged_by_date)}
- Stats by failed days:
{stats_summary}
"""
    )

@staticmethod
def _get_check_messages(merged_count_checks: Iterable[MergedCountCheck]):
    """Build one status line per merged check.

    Each line starts with a pass/fail marker chosen from ``is_valid`` and is
    followed by the check's string form. Returns the lines as a list, in
    input order.
    """
    messages = []
    for check in merged_count_checks:
        verdict = "✅ No Diff" if check.is_valid else "❌ Diff"
        messages.append(f"{verdict} {check}")
    return messages

@staticmethod
def _get_checks_by_date(
    merged_count_checks: Iterable[MergedCountCheck], check_date: date
) -> list[MergedCountCheck]:
    """Return, in input order, the merged checks whose ``check_date`` equals ``check_date``."""
    matching = []
    for candidate in merged_count_checks:
        if candidate.check_date == check_date:
            matching.append(candidate)
    return matching

@staticmethod
def _merge_by_check_date(
merged_count_checks: Iterable[MergedCountCheck],
Expand All @@ -144,7 +85,17 @@ def _merge_by_check_date(
entry["is_valid"] &= mcc.is_valid
entry["check_date"] = mcc.check_date

return {cd: MergedCountCheck(**data) for cd, data in merged.items()}
if not mcc.is_valid:
logger.info(
f"❌ Diff found for dimensions {mcc.dimensions}: "
f"Source Count: {mcc.source_count}, "
f"Target Count: {mcc.target_count}, "
f"Is Valid: {mcc.is_valid} "
)

return {
check_date: MergedCountCheck(**data) for check_date, data in merged.items()
}

def _merge_count_checks(
self, source_counts: Iterable[CountCheck], target_counts: Iterable[CountCheck]
Expand All @@ -156,24 +107,32 @@ def _merge_count_checks(
Iterable B: [2,4,5,7]
Output [(1,0), (2,2), (0,4), (5,5), (6,0), (0,7)]
"""

source_dict = {}
for count_check in source_counts:
source_dict.update(count_check.to_flatten_dimension_format())

target_dict = {}
for count_check in target_counts:
target_dict.update(count_check.to_flatten_dimension_format())

all_dims = set(source_dict.keys()) | set(target_dict.keys())

merged_count_checks = []
for dim in all_dims:
source_count = source_dict.get(dim)
target_count = target_dict.get(dim)
merged_count_check = MergedCountCheck.from_counts(
source_count, target_count
)
merged_count_checks.append(merged_count_check)

return sorted(merged_count_checks, key=lambda x: x.check_date)
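# Two-pointer merge of the two iterables, keyed on the dimension values (if any), using O(1) extra space.
# Both inputs must already be sorted in ascending order of that key, since only the smaller side is advanced.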
source_iter = iter(source_counts)
target_iter = iter(target_counts)
source_current = next(source_iter, None)
target_current = next(target_iter, None)
while source_current or target_current:
if source_current and target_current:
if (
source_current.get_dimension_values_as_tuple()
== target_current.get_dimension_values_as_tuple()
):
yield MergedCountCheck.from_counts(source_current, target_current)
source_current = next(source_iter, None)
target_current = next(target_iter, None)
elif (
source_current.get_dimension_values_as_tuple()
< target_current.get_dimension_values_as_tuple()
):
yield MergedCountCheck.from_counts(source_current, None)
source_current = next(source_iter, None)
else:
yield MergedCountCheck.from_counts(None, target_current)
target_current = next(target_iter, None)
elif source_current:
yield MergedCountCheck.from_counts(source_current, None)
source_current = next(source_iter, None)
elif target_current:
yield MergedCountCheck.from_counts(None, target_current)
target_current = next(target_iter, None)
Loading