Skip to content

Commit eaf87c6

Browse files
weijiangweijiang
authored andcommitted
enhance knapsack problem
1 parent cfabd91 commit eaf87c6

3 files changed

Lines changed: 46 additions & 26 deletions

File tree

knapsack/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# A naive recursive implementation of 0-1 Knapsack Problem
1+
# A recursive implementation of 0-N Knapsack Problem
22

33
This overview is taken from:
44

knapsack/knapsack.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,54 @@
1-
""" A naive recursive implementation of 0-1 Knapsack Problem
1+
""" A recursive implementation of 0-N Knapsack Problem
22
https://en.wikipedia.org/wiki/Knapsack_problem
33
"""
44
from __future__ import annotations
5+
from functools import lru_cache
56

67

7-
def knapsack(capacity: int, weights: list[int], values: list[int], counter: int) -> int:
8+
def knapsack(capacity: int, weights: list[int], values: list[int], counter: int, allow_repetition=False) -> int:
89
"""
910
Returns the maximum value that can be put in a knapsack of a capacity cap,
10-
whereby each weight w has a specific value val.
11+
whereby each weight w has a specific value val with option to allow repetitive selection of items
1112
1213
>>> cap = 50
1314
>>> val = [60, 100, 120]
1415
>>> w = [10, 20, 30]
1516
>>> c = len(val)
16-
>>> knapsack(cap, w, val, c)
17+
>>> knapsack(cap, w, val, c, False)
1718
220
1819
19-
The result is 220 cause the values of 100 and 120 got the weight of 50
20+
Given the repetition is NOT allowed,
21+
the result is 220 cause the values of 100 and 120 got the weight of 50
2022
which is the limit of the capacity.
21-
"""
23+
>>> knapsack(cap, w, val, c, True)
24+
300
2225
23-
# Base Case
24-
if counter == 0 or capacity == 0:
25-
return 0
26-
27-
# If weight of the nth item is more than Knapsack of capacity,
28-
# then this item cannot be included in the optimal solution,
29-
# else return the maximum of two cases:
30-
# (1) nth item included
31-
# (2) not included
32-
if weights[counter - 1] > capacity:
33-
return knapsack(capacity, weights, values, counter - 1)
34-
else:
35-
left_capacity = capacity - weights[counter - 1]
36-
new_value_included = values[counter - 1] + knapsack(
37-
left_capacity, weights, values, counter - 1
38-
)
39-
without_new_value = knapsack(capacity, weights, values, counter - 1)
40-
return max(new_value_included, without_new_value)
26+
Given the repetition is allowed,
27+
tthe result is 300 cause the values of 60*5 (pick 5 times)
28+
which is the limit of the capacity.
29+
"""
30+
@lru_cache()
31+
def knapsack_recur(cap: int, c: int) -> int:
32+
# Base Case
33+
if c == 0 or cap == 0:
34+
return 0
35+
36+
# If weight of the nth item is more than Knapsack of capacity,
37+
# then this item cannot be included in the optimal solution,
38+
# else return the maximum of two cases:
39+
# (1) not included
40+
# (2) nth item included one or more times (0-N), if allow_repetition is true
41+
# nth item included only once (0-1), if allow_repetition is false
42+
if weights[c - 1] > cap:
43+
return knapsack_recur(cap, c - 1)
44+
else:
45+
without_new_value = knapsack_recur(cap, c - 1)
46+
if allow_repetition:
47+
new_value_included = values[c - 1] + knapsack_recur(cap - weights[c - 1], c)
48+
else:
49+
new_value_included = values[c - 1] + knapsack_recur(cap - weights[c - 1], c - 1)
50+
return max(new_value_included, without_new_value)
51+
return knapsack_recur(capacity, counter)
4152

4253

4354
if __name__ == "__main__":

knapsack/tests/test_knapsack.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,17 @@ def test_knapsack(self):
4545
val = [60, 100, 120]
4646
w = [10, 20, 30]
4747
c = len(val)
48-
self.assertEqual(k.knapsack(cap, w, val, c), 220)
48+
self.assertEqual(k.knapsack(cap, w, val, c, False), 220)
4949

50+
def test_knapsack_repetition(self):
51+
"""
52+
test for the knapsack
53+
"""
54+
cap = 50
55+
val = [60, 100, 120]
56+
w = [10, 20, 30]
57+
c = len(val)
58+
self.assertEqual(k.knapsack(cap, w, val, c, True), 300)
5059

5160
if __name__ == "__main__":
5261
unittest.main()

0 commit comments

Comments
 (0)