@@ -81,19 +81,21 @@ function PPOPolicy(;
     )
 end

-function RLBase.get_prob(p::PPOPolicy{<:ActorCritic{<:NeuralNetworkApproximator{<:GaussianNetwork}}, Normal}, state::AbstractArray)
-    p.approximator.actor(send_to_device(
-        device(p.approximator),
-        state,
-    )) |> send_to_host |> StructArray{Normal}
+function RLBase.get_prob(
+    p::PPOPolicy{<:ActorCritic{<:NeuralNetworkApproximator{<:GaussianNetwork}},Normal},
+    state::AbstractArray,
+)
+    p.approximator.actor(send_to_device(device(p.approximator), state)) |>
+    send_to_host |>
+    StructArray{Normal}
 end

-function RLBase.get_prob(p::PPOPolicy{<:ActorCritic, Categorical}, state::AbstractArray)
-    logits = p.approximator.actor(send_to_device(
-        device(p.approximator),
-        state,
-    )) |> softmax |> send_to_host
-    [Categorical(x;check_args=false) for x in eachcol(logits)]
+function RLBase.get_prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::AbstractArray)
+    logits =
+        p.approximator.actor(send_to_device(device(p.approximator), state)) |>
+        softmax |>
+        send_to_host
+    [Categorical(x; check_args = false) for x in eachcol(logits)]
 end

 RLBase.get_prob(p::PPOPolicy, env::MultiThreadEnv) = get_prob(p, get_state(env))
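
For orientation, here is a minimal sketch (not part of the commit; the μ/σ values and logits are made up) of what the two `get_prob` methods above produce: a `StructArray{Normal}` for the Gaussian actor, and one `Categorical` per state column for the discrete actor.

using Distributions, StructArrays
using Flux: softmax

# Gaussian case: the actor returns (μ, σ); wrapping the pair in a
# StructArray{Normal} yields one Normal distribution per state.
μ, σ = [0.1f0, -0.2f0, 0.3f0], [1.0f0, 0.5f0, 2.0f0]
gaussians = StructArray{Normal}((μ, σ))

# Discrete case: logits are softmaxed column-wise, then each column
# becomes a Categorical over the actions (2 actions × 3 states here).
logits = Float32[1.0 0.5 -1.0; 0.0 2.0 1.0]
discrete = [Categorical(x; check_args = false) for x in eachcol(softmax(logits))]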
@@ -164,14 +166,14 @@ function RLBase.update!(p::PPOPolicy, t::PPOTrajectory)
                 if AC.actor isa NeuralNetworkApproximator{<:GaussianNetwork}
                     μ, σ = AC.actor(s)
                     log_p′ₐ = normlogpdf(μ, σ, a)
-                    entropy_loss = mean((log(2.0f0π)+1) / 2 .+ log.(σ))
+                    entropy_loss = mean((log(2.0f0π) + 1) / 2 .+ log.(σ))
                 else
                     # actor is assumed to return discrete logits
                     logit′ = AC.actor(s)
                     p′ = softmax(logit′)
                     log_p′ = logsoftmax(logit′)
                     log_p′ₐ = log_p′[CartesianIndex.(a, 1:length(a))]
-                    entropy_loss = -sum(p′ .* log_p′) * 1//size(p′, 2)
+                    entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
                 end

                 ratio = exp.(log_p′ₐ .- log_p)
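
The Gaussian branch's entropy term above is the closed-form differential entropy of Normal(μ, σ), H = (log(2π) + 1)/2 + log σ, averaged over the batch. A quick sanity check (a sketch, not in the commit) against Distributions.jl:

using Distributions

σ = 0.5f0
h_closed_form = (log(2.0f0π) + 1) / 2 + log(σ)
h_reference = entropy(Normal(0.0f0, σ))  # μ does not affect the entropy
h_closed_form ≈ h_reference              # true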
@@ -198,15 +200,18 @@ function RLBase.update!(p::PPOPolicy, t::PPOTrajectory)
     end
 end

-function (agent::Agent{<:Union{PPOPolicy, RandomStartPolicy{<:PPOPolicy}}})(::Training{PreActStage}, env::MultiThreadEnv)
+function (agent::Agent{<:Union{PPOPolicy,RandomStartPolicy{<:PPOPolicy}}})(
+    ::Training{PreActStage},
+    env::MultiThreadEnv,
+)
     state = get_state(env)
     dist = get_prob(agent.policy, env)

     # currently RandomPolicy returns a Matrix instead of a (vector of) distribution.
     if dist isa Matrix{<:Number}
-        dist = [Categorical(x;check_args=false) for x in eachcol(dist)]
+        dist = [Categorical(x; check_args = false) for x in eachcol(dist)]
     elseif dist isa Vector{<:Vector{<:Number}}
-        dist = [Categorical(x;check_args=false) for x in dist]
+        dist = [Categorical(x; check_args = false) for x in dist]
     end

     # !!! a little ugly
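
And a small sketch of the dist normalization in this hunk (values are made up; 2 actions × 3 parallel environments): a raw probability Matrix, as currently returned by RandomPolicy, is converted into one Categorical per environment before actions are sampled.

using Distributions

dist = [0.3 0.5 0.9; 0.7 0.5 0.1]  # actions × envs; each column sums to 1
if dist isa Matrix{<:Number}
    dist = [Categorical(x; check_args = false) for x in eachcol(dist)]
end
actions = rand.(dist)  # one sampled action index per environment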