To preface, I am a complete Julia newbie. I am trying to implement PPO for the first time, and I am having trouble updating the actor (and, by extension, critic) network parameters using the gradient from Flux.jl.
Below is a small snippet of my code. I think it should be sufficient, but if not, please let me know and I am happy to provide more.
```julia
batch_obs, batch_acts, _, batch_rtgo, _ = rollout(ppo_network)
V, curr_log_probs = evaluate(ppo_network, batch_obs, batch_acts)
V_scaled = V * maximum(abs.(batch_rtgo)) / maximum(abs.(V))
ratios = exp.(curr_log_probs - transpose(hcat(old_batch_log_probs...)))
ratios = ratios .+ 1e-8
A_k = batch_rtgo - deepcopy(V_scaled)
A_k = (A_k .- mean(A_k)) ./ (std(A_k) .+ 1e-10)
# surrogate objectives
surr1 = ratios .* A_k
surr2 = clamp.(ratios, 1 - clip, 1 + clip) .* A_k
actor_loss = -mean(min.(surr1, surr2))
actor_opt = Adam(lr)
actor_gs = gradient(() -> actor_loss, params(ppo.actor.model, ppo.actor.mean, ppo.actor.logstd))
# update parameters
update!(actor_opt, params([ppo.actor.model, ppo.actor.mean, ppo.actor.logstd]), actor_gs)
```
I inspected the gradient and found that the resulting `Grads` object has the parameters as keys, but all of the values are `nothing`.
How can I properly define the gradient so that I can perform this update?
I did some research and found that the loss must actually be computed inside the function passed to `gradient`; otherwise `gradient` does not know which parameters the loss depends on and ends up returning `nothing` for all of them. Is this correct?
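To illustrate what I mean, here is a minimal, self-contained sketch using a toy `Dense` layer in place of my actual networks (names here are just for illustration):

```julia
using Flux

model = Dense(2, 1)                 # toy stand-in for the actor network
x = rand(Float32, 2)
ps = Flux.params(model)

# Loss precomputed OUTSIDE the closure: `gradient` only sees a constant,
# so every parameter's entry in the result is `nothing`.
loss_value = sum(model(x))
gs_bad = gradient(() -> loss_value, ps)

# Loss computed INSIDE the closure: the forward pass is traced,
# so real gradients come back.
gs_good = gradient(() -> sum(model(x)), ps)
```

With this toy example, `gs_bad[model.weight] === nothing` while `gs_good[model.weight]` is an actual array, which matches what I am seeing in my PPO code.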
I tried writing the loss as a closure like the one below, which did give me nonzero gradients. The problem is that it re-runs `rollout` and `evaluate` inside the gradient call, which is costly, and it also runs into issues with large named tuples as the rollout batch size grows.
```julia
# define actor parameters
actor_ps = params(ppo.actor.model, ppo.actor.mean, ppo.actor.logstd)
# take the gradient, computing the loss inside the closure
actor_gs = gradient(actor_ps) do
    batch_obs, batch_acts, _, batch_rtgo, _ = rollout(ppo_network)
    V, curr_log_probs = evaluate(ppo_network, batch_obs, batch_acts)
    V_scaled = V * maximum(abs.(batch_rtgo)) / maximum(abs.(V))
    ratios = exp.(curr_log_probs - transpose(hcat(old_batch_log_probs...)))
    ratios = ratios .+ 1e-8
    A_k = batch_rtgo - deepcopy(V_scaled)
    A_k = (A_k .- mean(A_k)) ./ (std(A_k) .+ 1e-10)
    surr1 = ratios .* A_k
    surr2 = clamp.(ratios, 1 - clip, 1 + clip) .* A_k
    actor_loss = -mean(min.(surr1, surr2))
    return actor_loss
end
```
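For reference, here is a self-contained toy version of the restructuring I am considering: do the expensive data collection once, outside the gradient call, and only re-run the cheap forward pass inside the closure. The `Dense` layer, array shapes, and surrogate loss here are stand-ins for my `rollout`/`evaluate` setup, not my actual code:

```julia
using Flux, Statistics

model = Dense(4, 1)                   # toy stand-in for the actor
ps = Flux.params(model)

# Pretend these came from one expensive rollout, collected outside:
batch_obs  = rand(Float32, 4, 32)     # 4 features, 32 samples
advantages = randn(Float32, 1, 32)    # fixed w.r.t. the gradient

gs = gradient(ps) do
    preds = model(batch_obs)          # only this forward pass is differentiated
    -mean(preds .* advantages)        # toy surrogate loss
end
```

I am unsure whether this kind of split is the idiomatic way to avoid re-running the rollout, or whether there is a better pattern.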