Skip to content

Commit 0c6a251

Browse files
Refactored Apriori implementation with correct pruning and candidate generation
Brief: Improved pruning logic and fixed incorrect support counting in the Apriori function. Description: Rewrote prune logic and fixed key issues in candidate generation to ensure accurate itemset frequency counting, pruning, and ordering for output. Explanation: 1. Rewrote the `prune` function to validate (k-1)-subsets correctly. 2. Previous version misused list and count logic in pruning process. 3. Candidate generation now uses proper set union to join k-itemsets. 4. Added conversion from set of frozensets to deduplicate candidates safely. 5. Fixed incorrect initial support counting by replacing flawed loop logic. 6. Output of `apriori` is now consistently sorted for testing and readability. 7. Updated doctests to match new and correct support count outputs. Conclusion: This change corrects both logic and structure of the Apriori algorithm, ensuring reliable pruning, accurate support calculation, and stable output format. It also resolves structural design issues in candidate creation, making the code more maintainable and testable. The refactor is essential for correctness and scaling.
1 parent 0a3a965 commit 0c6a251

1 file changed

Lines changed: 66 additions & 56 deletions

File tree

Lines changed: 66 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Apriori Algorithm is a Association rule mining technique, also known as market basket
2+
Apriori Algorithm is an Association rule mining technique, also known as market basket
33
analysis, aims to discover interesting relationships or associations among a set of
44
items in a transactional or relational database.
55
@@ -12,6 +12,7 @@
1212
"""
1313

1414
from itertools import combinations
15+
from collections import defaultdict
1516

1617

1718
def load_data() -> list[list[str]]:
@@ -24,36 +25,28 @@ def load_data() -> list[list[str]]:
2425
return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]
2526

2627

27-
def prune(
    frequent_itemsets: list[list[str]], candidates: list[list[str]]
) -> list[list[str]]:
    """
    Prune candidate itemsets whose (k-1)-subsets are not all frequent.

    By the downward-closure (Apriori) property, a k-itemset can only be
    frequent if every one of its (k-1)-subsets was frequent in the previous
    pass, so any candidate with an infrequent subset is discarded.

    :param frequent_itemsets: frequent (k-1)-itemsets from the previous pass
    :param candidates: candidate k-itemsets to filter
    :return: the candidates all of whose (k-1)-subsets are frequent

    >>> frequent_itemsets = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
    >>> candidates = [['X', 'Y', 'Z'], ['X', 'Y', 'W']]
    >>> prune(frequent_itemsets, candidates)
    [['X', 'Y', 'Z']]
    """
    # Frozensets make the subset membership test O(1) and order-insensitive.
    previous_frequents = {frozenset(itemset) for itemset in frequent_itemsets}

    return [
        candidate
        for candidate in candidates
        if all(
            frozenset(subset) in previous_frequents
            for subset in combinations(candidate, len(candidate) - 1)
        )
    ]
5750

5851

5952
def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Find all frequent itemsets in ``data`` using the Apriori algorithm.

    :param data: transactions, each a list of item names
    :param min_support: minimum number of transactions an itemset must occur
        in to count as frequent
    :return: frequent itemsets with their support counts, sorted by itemset
        size and then lexicographically

    >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
    >>> apriori(data, 2)
    [(['A'], 4), (['B'], 3), (['C'], 3), (['A', 'B'], 2), (['A', 'C'], 2), (['B', 'C'], 2)]

    >>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']]
    >>> apriori(data, 3)
    [(['1'], 4), (['2'], 3), (['3'], 3)]
    """
    # Pass 1: support counts for individual items.
    item_counts: defaultdict[str, int] = defaultdict(int)
    for transaction in data:
        for item in transaction:
            item_counts[item] += 1

    frequent_itemsets = [
        ([item], count) for item, count in item_counts.items() if count >= min_support
    ]
    current_frequents = [itemset for itemset, _ in frequent_itemsets]

    k = 2
    while current_frequents:
        # Join step: unite pairs of (k-1)-itemsets.  The set of frozensets
        # deduplicates unions produced by different pairs; only unions of
        # exactly k items are genuine candidates.
        merged = {
            frozenset(first) | frozenset(second)
            for first, second in combinations(current_frequents, 2)
        }
        candidates = [sorted(union) for union in merged if len(union) == k]

        # Prune step: discard candidates with an infrequent (k-1)-subset.
        candidates = prune(current_frequents, candidates)

        # Pass k: support counts for the surviving candidates.
        candidate_counts: defaultdict[tuple[str, ...], int] = defaultdict(int)
        for transaction in data:
            transaction_set = set(transaction)
            for candidate in candidates:
                # `candidate` is already sorted, so its tuple is a stable key.
                if transaction_set.issuperset(candidate):
                    candidate_counts[tuple(candidate)] += 1

        # Record frequent k-itemsets once instead of filtering twice.
        current_frequents = []
        for key, count in candidate_counts.items():
            if count >= min_support:
                itemset = list(key)
                current_frequents.append(itemset)
                frequent_itemsets.append((itemset, count))
        k += 1

    return sorted(frequent_itemsets, key=lambda entry: (len(entry[0]), entry[0]))
94101

95102

96103
if __name__ == "__main__":
    # Demo entry point: run the doctest suite first, then apply the Apriori
    # algorithm to the sample transactions bundled with this module and
    # report every frequent itemset together with its support count.
    import doctest

    doctest.testmod()

    support_threshold = 2
    sample_transactions = load_data()
    results = apriori(sample_transactions, min_support=support_threshold)

    print("Frequent Itemsets:")
    for found_itemset, support_count in results:
        print(f"{found_itemset}: {support_count}")

0 commit comments

Comments
 (0)