From 0c6a25144794118b510efa90e33e8a5c2d20f802 Mon Sep 17 00:00:00 2001 From: Fardin Moghaddam Pour Date: Wed, 30 Apr 2025 11:31:46 +0330 Subject: [PATCH 1/2] Refactored Apriori implementation with correct pruning and candidate generation Brief: Improved pruning logic and fixed the core support-count logic in the Apriori function. Description: Rewrote the prune logic and fixed key issues in candidate generation to ensure accurate itemset frequency counting, pruning, and ordering for output. Explanation: 1. Rewrote the `prune` function to validate (k-1)-subsets correctly. 2. The previous version misused list and count logic in the pruning process. 3. Candidate generation now uses proper set union to join k-itemsets. 4. Added conversion from a set of frozensets to deduplicate candidates safely. 5. Fixed incorrect initial support counting by replacing flawed loop logic. 6. Output of `apriori` is now consistently sorted for testing and readability. 7. Updated doctests to match the new, correct support-count outputs. Conclusion: This change corrects both the logic and the structure of the Apriori algorithm, ensuring reliable pruning, accurate support calculation, and a stable output format. It also resolves structural design issues in candidate creation, making the code more maintainable and testable. The refactor is essential for correctness and scaling. 
--- machine_learning/apriori_algorithm.py | 122 ++++++++++++++------------ 1 file changed, 66 insertions(+), 56 deletions(-) diff --git a/machine_learning/apriori_algorithm.py b/machine_learning/apriori_algorithm.py index 09a89ac236bd..876363392095 100644 --- a/machine_learning/apriori_algorithm.py +++ b/machine_learning/apriori_algorithm.py @@ -1,5 +1,5 @@ """ -Apriori Algorithm is a Association rule mining technique, also known as market basket +Apriori Algorithm is an Association rule mining technique, also known as market basket analysis, aims to discover interesting relationships or associations among a set of items in a transactional or relational database. @@ -12,6 +12,7 @@ """ from itertools import combinations +from collections import defaultdict def load_data() -> list[list[str]]: @@ -24,36 +25,28 @@ def load_data() -> list[list[str]]: return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]] -def prune(itemset: list, candidates: list, length: int) -> list: +def prune(frequent_itemsets: list[list[str]], candidates: list[list[str]]) -> list[list[str]]: """ - Prune candidate itemsets that are not frequent. - The goal of pruning is to filter out candidate itemsets that are not frequent. This - is done by checking if all the (k-1) subsets of a candidate itemset are present in - the frequent itemsets of the previous iteration (valid subsequences of the frequent - itemsets from the previous iteration). - - Prunes candidate itemsets that are not frequent. - - >>> itemset = ['X', 'Y', 'Z'] - >>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']] - >>> prune(itemset, candidates, 2) - [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']] - - >>> itemset = ['1', '2', '3', '4'] - >>> candidates = ['1', '2', '4'] - >>> prune(itemset, candidates, 3) - [] + Prunes candidate itemsets by ensuring all (k-1)-subsets exist in previous frequent itemsets. 
+ + >>> frequent_itemsets = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']] + >>> candidates = [['X', 'Y', 'Z'], ['X', 'Y', 'W']] + >>> prune(frequent_itemsets, candidates) + [['X', 'Y', 'Z']] """ - pruned = [] + + previous_frequents = set(frozenset(itemset) for itemset in frequent_itemsets) + + pruned_candidates = [] for candidate in candidates: - is_subsequence = True - for item in candidate: - if item not in itemset or itemset.count(item) < length - 1: - is_subsequence = False - break - if is_subsequence: - pruned.append(candidate) - return pruned + all_subsets_frequent = all( + frozenset(subset) in previous_frequents + for subset in combinations(candidate, len(candidate) - 1) + ) + if all_subsets_frequent: + pruned_candidates.append(candidate) + + return pruned_candidates def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]: @@ -62,52 +55,69 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in >>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']] >>> apriori(data, 2) - [(['A', 'B'], 1), (['A', 'C'], 2), (['B', 'C'], 2)] + [(['A'], 4), (['B'], 3), (['C'], 3), (['A', 'B'], 2), (['A', 'C'], 2), (['B', 'C'], 2)] >>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']] >>> apriori(data, 3) - [] + [(['1'], 4), (['2'], 3), (['3'], 3)] """ - itemset = [list(transaction) for transaction in data] - frequent_itemsets = [] - length = 1 - while itemset: - # Count itemset support - counts = [0] * len(itemset) - for transaction in data: - for j, candidate in enumerate(itemset): - if all(item in transaction for item in candidate): - counts[j] += 1 + item_counts = defaultdict(int) + for transaction in data: + for item in transaction: + item_counts[item] += 1 + + current_frequents = [[item] for item, count in item_counts.items() if count >= min_support] + frequent_itemsets = [([item], count) for item, count in item_counts.items() if count >= min_support] - # Prune infrequent itemsets - 
itemset = [item for i, item in enumerate(itemset) if counts[i] >= min_support] + k = 2 + while current_frequents: + candidates = [sorted(list(set(i) | set(j))) + for i in current_frequents + for j in current_frequents + if len(set(i).union(j)) == k] - # Append frequent itemsets (as a list to maintain order) - for i, item in enumerate(itemset): - frequent_itemsets.append((sorted(item), counts[i])) + candidates = [list(c) for c in {frozenset(c) for c in candidates}] - length += 1 - itemset = prune(itemset, list(combinations(itemset, length)), length) + candidates = prune(current_frequents, candidates) - return frequent_itemsets + candidate_counts = defaultdict(int) + for transaction in data: + t_set = set(transaction) + for candidate in candidates: + if set(candidate).issubset(t_set): + candidate_counts[tuple(sorted(candidate))] += 1 + + current_frequents = [list(key) for key, count in candidate_counts.items() if count >= min_support] + frequent_itemsets.extend( + [ + (list(key), count) for key, count in candidate_counts.items() if count >= min_support + ] + ) + + k += 1 + + return sorted(frequent_itemsets, key=lambda x: (len(x[0]), x[0])) if __name__ == "__main__": """ Apriori algorithm for finding frequent itemsets. - Args: - data: A list of transactions, where each transaction is a list of items. - min_support: The minimum support threshold for frequent itemsets. + This script loads sample transaction data and runs the Apriori algorithm + with a user-defined minimum support threshold. - Returns: - A list of frequent itemsets along with their support counts. + The result is a list of frequent itemsets along with their support counts. 
""" import doctest doctest.testmod() - # user-defined threshold or minimum support level - frequent_itemsets = apriori(data=load_data(), min_support=2) - print("\n".join(f"{itemset}: {support}" for itemset, support in frequent_itemsets)) + transactions = load_data() + min_support_threshold = 2 + + frequent_itemsets = apriori(transactions, min_support=min_support_threshold) + + print("Frequent Itemsets:") + for itemset, support in frequent_itemsets: + print(f"{itemset}: {support}") From afdd40b3e6a9e3be66c1039170112e4a7eb4901f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Apr 2025 08:09:38 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- machine_learning/apriori_algorithm.py | 30 +++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/machine_learning/apriori_algorithm.py b/machine_learning/apriori_algorithm.py index 876363392095..78af289eab90 100644 --- a/machine_learning/apriori_algorithm.py +++ b/machine_learning/apriori_algorithm.py @@ -25,7 +25,9 @@ def load_data() -> list[list[str]]: return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]] -def prune(frequent_itemsets: list[list[str]], candidates: list[list[str]]) -> list[list[str]]: +def prune( + frequent_itemsets: list[list[str]], candidates: list[list[str]] +) -> list[list[str]]: """ Prunes candidate itemsets by ensuring all (k-1)-subsets exist in previous frequent itemsets. 
@@ -67,15 +69,21 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in for item in transaction: item_counts[item] += 1 - current_frequents = [[item] for item, count in item_counts.items() if count >= min_support] - frequent_itemsets = [([item], count) for item, count in item_counts.items() if count >= min_support] + current_frequents = [ + [item] for item, count in item_counts.items() if count >= min_support + ] + frequent_itemsets = [ + ([item], count) for item, count in item_counts.items() if count >= min_support + ] k = 2 while current_frequents: - candidates = [sorted(list(set(i) | set(j))) - for i in current_frequents - for j in current_frequents - if len(set(i).union(j)) == k] + candidates = [ + sorted(list(set(i) | set(j))) + for i in current_frequents + for j in current_frequents + if len(set(i).union(j)) == k + ] candidates = [list(c) for c in {frozenset(c) for c in candidates}] @@ -88,10 +96,14 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in if set(candidate).issubset(t_set): candidate_counts[tuple(sorted(candidate))] += 1 - current_frequents = [list(key) for key, count in candidate_counts.items() if count >= min_support] + current_frequents = [ + list(key) for key, count in candidate_counts.items() if count >= min_support + ] frequent_itemsets.extend( [ - (list(key), count) for key, count in candidate_counts.items() if count >= min_support + (list(key), count) + for key, count in candidate_counts.items() + if count >= min_support ] )