@@ -22,8 +22,8 @@ function BestResponsePolicy(
2222 state_type = String,
2323 action_type = Int,
2424)
25- # S = typeof(get_state (env)) # TODO : currently it will break the OpenSpielEnv. Can not get information set for chance player
26- # A = eltype(get_actions (env)) # TODO : for chance players it will return ActionProbPair
25+ # S = typeof(state (env)) # TODO : currently it will break the OpenSpielEnv. Can not get information set for chance player
26+ # A = eltype(action_space (env)) # TODO : for chance players it will return ActionProbPair
2727 S = state_type
2828 A = action_type
2929 E = typeof (env)
@@ -45,31 +45,31 @@ function BestResponsePolicy(
4545end
4646
4747function (p:: BestResponsePolicy )(env:: AbstractEnv )
48- if get_current_player (env) == p. best_responder
48+ if current_player (env) == p. best_responder
4949 best_response_action (p, env)
5050 else
5151 p. policy (env)
5252 end
5353end
5454
5555function init_cfr_reach_prob! (p, env, reach_prob = 1.0 )
56- if ! get_terminal (env)
57- if get_current_player (env) == p. best_responder
58- push! (get! (p. cfr_reach_prob, get_state (env), []), env => reach_prob)
56+ if ! is_terminated (env)
57+ if current_player (env) == p. best_responder
58+ push! (get! (p. cfr_reach_prob, state (env), []), env => reach_prob)
5959
60- for a in get_legal_actions (env)
60+ for a in legal_action_space (env)
6161 init_cfr_reach_prob! (p, child (env, a), reach_prob)
6262 end
63- elseif get_current_player (env) == get_chance_player (env)
64- for a:: ActionProbPair in get_actions (env)
63+ elseif current_player (env) == chance_player (env)
64+ for a:: ActionProbPair in action_space (env)
6565 init_cfr_reach_prob! (p, child (env, a), reach_prob * a. prob)
6666 end
6767 else # opponents
68- for a in get_legal_actions (env)
68+ for a in legal_action_space (env)
6969 init_cfr_reach_prob! (
7070 p,
7171 child (env, a),
72- reach_prob * get_prob (p. policy, env, a),
72+ reach_prob * prob (p. policy, env, a),
7373 )
7474 end
7575 end
7878
7979function best_response_value (p, env)
8080 get! (p. best_response_value_cache, env) do
81- if get_terminal (env)
82- get_reward (env, p. best_responder)
83- elseif get_current_player (env) == p. best_responder
81+ if is_terminated (env)
82+ reward (env, p. best_responder)
83+ elseif current_player (env) == p. best_responder
8484 a = best_response_action (p, env)
8585 best_response_value (p, child (env, a))
86- elseif get_current_player (env) == get_chance_player (env)
86+ elseif current_player (env) == chance_player (env)
8787 v = 0.0
88- for a:: ActionProbPair in get_actions (env)
88+ for a:: ActionProbPair in action_space (env)
8989 v += a. prob * best_response_value (p, child (env, a))
9090 end
9191 v
9292 else
9393 v = 0.0
94- for a in get_legal_actions (env)
95- v += get_prob (p. policy, env, a) * best_response_value (p, child (env, a))
94+ for a in legal_action_space (env)
95+ v += prob (p. policy, env, a) * best_response_value (p, child (env, a))
9696 end
9797 v
9898 end
9999 end
100100end
101101
102102function best_response_action (p, env)
103- get! (p. best_response_action_cache, get_state (env)) do
103+ get! (p. best_response_action_cache, state (env)) do
104104 best_action, best_action_value = nothing , typemin (Float64)
105- for a in get_legal_actions (env)
106- # for each information set (`get_state (env)` here), we may have several paths to reach it
105+ for a in legal_action_space (env)
106+ # for each information set (`state (env)` here), we may have several paths to reach it
107107 # here we sum the cfr reach prob weighted value to find out the best action
108- v = sum (p. cfr_reach_prob[get_state (env)]) do (e, reach_prob)
108+ v = sum (p. cfr_reach_prob[state (env)]) do (e, reach_prob)
109109 reach_prob * best_response_value (p, child (e, a))
110110 end
111111 if v > best_action_value
@@ -118,10 +118,10 @@ end
118118
119119RLBase. update! (p:: BestResponsePolicy , args... ) = nothing
120120
121- function RLBase. get_prob (p:: BestResponsePolicy , env:: AbstractEnv )
122- if get_current_player (env) == p. best_responder
123- onehot (p (env), get_actions (env))
121+ function RLBase. prob (p:: BestResponsePolicy , env:: AbstractEnv )
122+ if current_player (env) == p. best_responder
123+ onehot (p (env), action_space (env))
124124 else
125- get_prob (p. policy, env)
125+ prob (p. policy, env)
126126 end
127127end
0 commit comments