|
24 | 24 |
|
25 | 25 | """ |
26 | 26 |
|
| 27 | +from abc import ABC, abstractmethod |
| 28 | + |
27 | 29 | import matplotlib.pyplot as plt |
28 | 30 | import numpy as np |
29 | 31 |
|
@@ -65,7 +67,32 @@ def pull(self, arm_index: int) -> int: |
65 | 67 | # Epsilon-Greedy strategy |
66 | 68 |
|
67 | 69 |
|
68 | | -class EpsilonGreedy: |
| 70 | +class Strategy(ABC): |
| 71 | + """ |
| 72 | + Base class for all strategies. |
| 73 | + """ |
| 74 | + |
| 75 | + @abstractmethod |
| 76 | + def select_arm(self) -> int: |
| 77 | + """ |
| 78 | + Select an arm to pull. |
| 79 | +
|
| 80 | + Returns: |
| 81 | + The index of the arm to pull. |
| 82 | + """ |
| 83 | + |
| 84 | + @abstractmethod |
| 85 | + def update(self, arm_index: int, reward: int) -> None: |
| 86 | + """ |
| 86 | + Update the strategy with the reward observed for an arm. |
| 88 | +
|
| 89 | + Args: |
| 90 | + arm_index: The index of the arm that was pulled. |
| 91 | + reward: The reward received from that arm. |
| 92 | + """ |
| 93 | + |
| 94 | + |
| 95 | +class EpsilonGreedy(Strategy): |
69 | 96 | """ |
70 | 97 | A class for a simple implementation of the Epsilon-Greedy strategy. |
71 | 98 | Follow this link to learn more: |
@@ -126,7 +153,7 @@ def update(self, arm_index: int, reward: int) -> None: |
126 | 153 | # Upper Confidence Bound (UCB) |
127 | 154 |
|
128 | 155 |
|
129 | | -class UCB: |
| 156 | +class UCB(Strategy): |
130 | 157 | """ |
131 | 158 | A class for the Upper Confidence Bound (UCB) strategy. |
132 | 159 | Follow this link to learn more: |
@@ -185,7 +212,7 @@ def update(self, arm_index: int, reward: int) -> None: |
185 | 212 | # Thompson Sampling |
186 | 213 |
|
187 | 214 |
|
188 | | -class ThompsonSampling: |
| 215 | +class ThompsonSampling(Strategy): |
189 | 216 | """ |
190 | 217 | A class for the Thompson Sampling strategy. |
191 | 218 | Follow this link to learn more: |
@@ -245,7 +272,7 @@ def update(self, arm_index: int, reward: int) -> None: |
245 | 272 |
|
246 | 273 |
|
247 | 274 | # Random strategy (full exploration) |
248 | | -class RandomStrategy: |
| 275 | +class RandomStrategy(Strategy): |
249 | 276 | """ |
250 | 277 | A class for choosing an arm totally at random at each round to give |
251 | 278 | a better comparison with the other optimised strategies. |
@@ -292,7 +319,7 @@ def update(self, arm_index: int, reward: int) -> None: |
292 | 319 | # Greedy strategy (full exploitation) |
293 | 320 |
|
294 | 321 |
|
295 | | -class GreedyStrategy: |
| 322 | +class GreedyStrategy(Strategy): |
296 | 323 | """ |
297 | 324 | A class for the Greedy strategy to show how full exploitation can be |
298 | 325 | detrimental to the performance of the strategy. |
@@ -351,7 +378,7 @@ def test_mab_strategies() -> None: |
351 | 378 | arms_probabilities = [0.1, 0.3, 0.5, 0.8] # True probabilities |
352 | 379 |
|
353 | 380 | bandit = Bandit(arms_probabilities) |
354 | | - strategies = { |
| 381 | + strategies: dict[str, Strategy] = { |
355 | 382 | "Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, num_arms=num_arms), |
356 | 383 | "UCB": UCB(num_arms=num_arms), |
357 | 384 | "Thompson Sampling": ThompsonSampling(num_arms=num_arms), |
|
0 commit comments