Skip to content

Commit db8d328

Browse files
JossGeekCopilot
andcommitted
refactor: add return types to helper functions in Apriori algorithm
Co-authored-by: Copilot <copilot@github.com>
2 parents e4d1ba0 + 7bb97a8 commit db8d328

1 file changed

Lines changed: 10 additions & 17 deletions

File tree

machine_learning/apriori_algorithm.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ def load_data() -> list[list[str]]:
2727

2828
# ---------- Helpers ----------
2929

30-
def get_support(itemset: frozenset, transactions: list[set]):
30+
31+
def get_support(itemset: frozenset, transactions: list[set]) -> int:
3132
"""Compute support count of an itemset efficiently."""
3233
return sum(1 for t in transactions if itemset.issubset(t))
3334

3435

35-
def generate_candidates(prev_frequent: set[frozenset], k: int):
36+
def generate_candidates(prev_frequent: set[frozenset], k: int) -> set[frozenset]:
3637
"""
3738
Generate candidate itemsets of size k from frequent itemsets of size k-1.
3839
"""
@@ -48,7 +49,7 @@ def generate_candidates(prev_frequent: set[frozenset], k: int):
4849
return candidates
4950

5051

51-
def has_infrequent_subset(candidate: frozenset, prev_frequent: set[frozenset]):
52+
def has_infrequent_subset(candidate: frozenset, prev_frequent: set[frozenset]) -> bool:
5253
"""
5354
Apriori pruning: all (k-1)-subsets must be frequent.
5455
"""
@@ -60,7 +61,8 @@ def has_infrequent_subset(candidate: frozenset, prev_frequent: set[frozenset]):
6061

6162
# ---------- Main Apriori ----------
6263

63-
def apriori(data: list[list[str]], min_support: int):
64+
65+
def apriori(data: list[list[str]], min_support: int) -> list[tuple[frozenset, int]]:
6466
transactions = [set(t) for t in data]
6567

6668
# 1. initial 1-itemsets
@@ -70,14 +72,11 @@ def apriori(data: list[list[str]], min_support: int):
7072
item_counts[frozenset([item])] += 1
7173

7274
frequent = {
73-
itemset for itemset, count in item_counts.items()
74-
if count >= min_support
75+
itemset for itemset, count in item_counts.items() if count >= min_support
7576
}
7677

7778
all_frequents = [
78-
(next(iter(i)), c)
79-
for i, c in item_counts.items()
80-
if c >= min_support
79+
(next(iter(i)), c) for i, c in item_counts.items() if c >= min_support
8180
]
8281

8382
k = 2
@@ -87,10 +86,7 @@ def apriori(data: list[list[str]], min_support: int):
8786
candidates = generate_candidates(frequent, k)
8887

8988
# 3. prune
90-
candidates = {
91-
c for c in candidates
92-
if not has_infrequent_subset(c, frequent)
93-
}
89+
candidates = {c for c in candidates if not has_infrequent_subset(c, frequent)}
9490

9591
# 4. count support
9692
candidate_counts = defaultdict(int)
@@ -100,10 +96,7 @@ def apriori(data: list[list[str]], min_support: int):
10096
candidate_counts[c] += 1
10197

10298
# 5. filter frequent
103-
frequent = {
104-
c for c, count in candidate_counts.items()
105-
if count >= min_support
106-
}
99+
frequent = {c for c, count in candidate_counts.items() if count >= min_support}
107100

108101
all_frequents.extend(
109102
(sorted(c), count)

0 commit comments

Comments
 (0)