@@ -41,7 +41,7 @@ def __init__(self, probabilities: list[float]) -> None:
             probabilities: List of probabilities for each arm.
         """
         self.probabilities = probabilities
-        self.k = len(probabilities)
+        self.num_arms = len(probabilities)

     def pull(self, arm_index: int) -> int:
         """
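Note: `pull`'s body sits outside this hunk. For context, each arm presumably pays a Bernoulli reward — a minimal sketch, assuming the reward is 1 with the arm's probability and 0 otherwise:

```python
import numpy as np

rng = np.random.default_rng()

def pull(probability: float) -> int:
    # Bernoulli reward: 1 with the given probability, 0 otherwise.
    return int(rng.random() < probability)

# An arm with p=0.8 pays out on roughly 800 of 1000 pulls.
print(sum(pull(0.8) for _ in range(1000)))
```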
@@ -72,18 +72,18 @@ class EpsilonGreedy:
     https://medium.com/analytics-vidhya/the-epsilon-greedy-algorithm-for-reinforcement-learning-5fe6f96dc870
     """

-    def __init__(self, epsilon: float, k: int) -> None:
+    def __init__(self, epsilon: float, num_arms: int) -> None:
         """
         Initialize the Epsilon-Greedy strategy.

         Args:
             epsilon: The probability of exploring new arms.
-            k: The number of arms.
+            num_arms: The number of arms.
         """
         self.epsilon = epsilon
-        self.k = k
-        self.counts = np.zeros(k)
-        self.values = np.zeros(k)
+        self.num_arms = num_arms
+        self.counts = np.zeros(num_arms)
+        self.values = np.zeros(num_arms)

     def select_arm(self) -> int:
         """
@@ -93,14 +93,14 @@ def select_arm(self) -> int:
             The index of the arm to pull.

         Example:
-            >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
+            >>> strategy = EpsilonGreedy(epsilon=0.1, num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             np.True_
         """
         rng = np.random.default_rng()

         if rng.random() < self.epsilon:
-            return rng.integers(self.k)
+            return rng.integers(self.num_arms)
         else:
             return np.argmax(self.values)

@@ -113,7 +113,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
+            >>> strategy = EpsilonGreedy(epsilon=0.1, num_arms=3)
             >>> strategy.update(0, 1)
             >>> strategy.counts[0] == 1
             np.True_
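Note: `update`'s body also falls outside the hunk. A standard incremental-mean update would look like the sketch below — an assumption, not necessarily this file's exact code:

```python
def update(self, arm_index: int, reward: int) -> None:
    # Running average: new_mean = old_mean + (reward - old_mean) / n
    self.counts[arm_index] += 1
    n = self.counts[arm_index]
    self.values[arm_index] += (reward - self.values[arm_index]) / n
```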
@@ -133,16 +133,16 @@ class UCB:
     https://people.maths.bris.ac.uk/~maajg/teaching/stochopt/ucb.pdf
     """

-    def __init__(self, k: int) -> None:
+    def __init__(self, num_arms: int) -> None:
         """
         Initialize the UCB strategy.

         Args:
-            k: The number of arms.
+            num_arms: The number of arms.
         """
-        self.k = k
-        self.counts = np.zeros(k)
-        self.values = np.zeros(k)
+        self.num_arms = num_arms
+        self.counts = np.zeros(num_arms)
+        self.values = np.zeros(num_arms)
         self.total_counts = 0

     def select_arm(self) -> int:
@@ -153,13 +153,14 @@ def select_arm(self) -> int:
             The index of the arm to pull.

         Example:
-            >>> strategy = UCB(k=3)
+            >>> strategy = UCB(num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             True
         """
-        if self.total_counts < self.k:
+        if self.total_counts < self.num_arms:
             return self.total_counts
-        ucb_values = self.values + np.sqrt(2 * np.log(self.total_counts) / self.counts)
+        ucb_values = self.values + \
+            np.sqrt(2 * np.log(self.total_counts) / self.counts)
         return np.argmax(ucb_values)

     def update(self, arm_index: int, reward: int) -> None:
@@ -171,7 +172,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = UCB(k=3)
+            >>> strategy = UCB(num_arms=3)
             >>> strategy.update(0, 1)
             >>> strategy.counts[0] == 1
             np.True_
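Note: the selection rule in this hunk adds the exploration bonus sqrt(2·ln(t) / n_i) to each arm's empirical mean, so rarely-pulled arms get larger bonuses. A small worked example with made-up counts:

```python
import numpy as np

values = np.array([0.5, 0.8, 0.2])    # empirical means per arm
counts = np.array([10.0, 50.0, 5.0])  # pulls per arm
total = counts.sum()                  # 65 rounds so far

ucb = values + np.sqrt(2 * np.log(total) / counts)
print(ucb.round(3))         # [1.414 1.209 1.492]
print(int(np.argmax(ucb)))  # 2: the rarely-pulled arm's bonus outweighs arm 1's higher mean
```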
@@ -192,16 +193,16 @@ class ThompsonSampling:
     https://en.wikipedia.org/wiki/Thompson_sampling
     """

-    def __init__(self, k: int) -> None:
+    def __init__(self, num_arms: int) -> None:
         """
         Initialize the Thompson Sampling strategy.

         Args:
-            k: The number of arms.
+            num_arms: The number of arms.
         """
-        self.k = k
-        self.successes = np.zeros(k)
-        self.failures = np.zeros(k)
+        self.num_arms = num_arms
+        self.successes = np.zeros(num_arms)
+        self.failures = np.zeros(num_arms)

     def select_arm(self) -> int:
         """
@@ -212,14 +213,15 @@ def select_arm(self) -> int:
             which relies on the Beta distribution.

         Example:
-            >>> strategy = ThompsonSampling(k=3)
+            >>> strategy = ThompsonSampling(num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             np.True_
         """
         rng = np.random.default_rng()

         samples = [
-            rng.beta(self.successes[i] + 1, self.failures[i] + 1) for i in range(self.k)
+            rng.beta(self.successes[i] + 1, self.failures[i] + 1)
+            for i in range(self.num_arms)
         ]
         return np.argmax(samples)

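Note: the comprehension above draws one sample per arm from Beta(successes + 1, failures + 1), which is the posterior under a uniform prior, then pulls the arm with the highest draw. `Generator.beta` also broadcasts over arrays, so an equivalent vectorised sketch:

```python
import numpy as np

rng = np.random.default_rng()

successes = np.array([8.0, 2.0])
failures = np.array([2.0, 8.0])

# One posterior draw per arm; the better-looking arm 0 wins most rounds,
# but arm 1 still gets occasional exploratory pulls.
samples = rng.beta(successes + 1, failures + 1)
print(samples, int(np.argmax(samples)))
```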
@@ -232,7 +234,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = ThompsonSampling(k=3)
+            >>> strategy = ThompsonSampling(num_arms=3)
             >>> strategy.update(0, 1)
             >>> strategy.successes[0] == 1
             np.True_
@@ -250,14 +252,14 @@ class RandomStrategy:
     a better comparison with the other optimised strategies.
     """

-    def __init__(self, k: int):
+    def __init__(self, num_arms: int) -> None:
         """
         Initialize the Random strategy.

         Args:
-            k: The number of arms.
+            num_arms: The number of arms.
         """
-        self.k = k
+        self.num_arms = num_arms

     def select_arm(self) -> int:
         """
@@ -267,12 +269,12 @@ def select_arm(self) -> int:
             The index of the arm to pull.

         Example:
-            >>> strategy = RandomStrategy(k=3)
+            >>> strategy = RandomStrategy(num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             np.True_
         """
         rng = np.random.default_rng()
-        return rng.integers(self.k)
+        return rng.integers(self.num_arms)

     def update(self, arm_index: int, reward: int) -> None:
         """
@@ -283,7 +285,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = RandomStrategy(k=3)
+            >>> strategy = RandomStrategy(num_arms=3)
             >>> strategy.update(0, 1)
         """

@@ -297,16 +299,16 @@ class GreedyStrategy:
     detrimental to the performance of the strategy.
     """

-    def __init__(self, k: int):
+    def __init__(self, num_arms: int) -> None:
         """
         Initialize the Greedy strategy.

         Args:
-            k: The number of arms.
+            num_arms: The number of arms.
         """
-        self.k = k
-        self.counts = np.zeros(k)
-        self.values = np.zeros(k)
+        self.num_arms = num_arms
+        self.counts = np.zeros(num_arms)
+        self.values = np.zeros(num_arms)

     def select_arm(self) -> int:
         """
@@ -316,7 +318,7 @@ def select_arm(self) -> int:
             The index of the arm to pull.

         Example:
-            >>> strategy = GreedyStrategy(k=3)
+            >>> strategy = GreedyStrategy(num_arms=3)
             >>> 0 <= strategy.select_arm() < 3
             np.True_
         """
@@ -331,7 +333,7 @@ def update(self, arm_index: int, reward: int) -> None:
             reward: The reward for the arm.

         Example:
-            >>> strategy = GreedyStrategy(k=3)
+            >>> strategy = GreedyStrategy(num_arms=3)
             >>> strategy.update(0, 1)
             >>> strategy.counts[0] == 1
             np.True_
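Note: with zeroed estimates and no exploration term, `GreedyStrategy` can lock onto whichever arm pays out first — the cold-start behaviour its docstring warns is detrimental to performance. A quick illustration, assuming the same incremental-mean update sketched earlier:

```python
import numpy as np

values = np.zeros(3)
counts = np.zeros(3)

# Suppose arm 0's very first pull happens to pay out:
counts[0] += 1
values[0] += (1 - values[0]) / counts[0]  # values -> [1., 0., 0.]

# argmax now always returns arm 0; arms 1 and 2 are never re-estimated.
print(int(np.argmax(values)))  # 0
```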
@@ -346,16 +348,16 @@ def test_mab_strategies() -> None:
     Test the MAB strategies.
     """
     # Simulation
-    k = 4
+    num_arms = 4
     arms_probabilities = [0.1, 0.3, 0.5, 0.8]  # True probabilities

     bandit = Bandit(arms_probabilities)
     strategies = {
-        "Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, k=k),
-        "UCB": UCB(k=k),
-        "Thompson Sampling": ThompsonSampling(k=k),
-        "Full Exploration(Random)": RandomStrategy(k=k),
-        "Full Exploitation(Greedy)": GreedyStrategy(k=k),
+        "Epsilon-Greedy": EpsilonGreedy(epsilon=0.1, num_arms=num_arms),
+        "UCB": UCB(num_arms=num_arms),
+        "Thompson Sampling": ThompsonSampling(num_arms=num_arms),
+        "Full Exploration(Random)": RandomStrategy(num_arms=num_arms),
+        "Full Exploitation(Greedy)": GreedyStrategy(num_arms=num_arms),
     }

     num_rounds = 1000
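Note: the simulation loop itself sits below this hunk. A hypothetical sketch of the driver — the bookkeeping is an assumption, since it is outside the diff:

```python
# Hypothetical driver loop; names match the diff, body is an assumption.
for name, strategy in strategies.items():
    total_reward = 0
    for _ in range(num_rounds):
        arm = strategy.select_arm()   # strategy picks an arm
        reward = bandit.pull(arm)     # environment returns a 0/1 reward
        strategy.update(arm, reward)  # strategy learns from the outcome
        total_reward += reward
    print(f"{name}: total reward over {num_rounds} rounds = {total_reward}")
```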