In my last post I tried to implement a DQN model for a chess bot. Progressing further, I implemented PPO for a comparatively simpler problem so that I could actually measure the performance.
What is Proximal Policy Optimization?
PPO (Proximal Policy Optimization) trains an agent to pick good actions for a given state by leveraging an advantage function. The advantage is computed using a critic, which evaluates how favorable a specific state and action are relative to what the current policy would achieve on average.
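The "proximal" part refers to the clipped surrogate objective, which keeps each update close to the old policy. Below is a minimal sketch of that loss in PyTorch; the function name and the clip_eps value are illustrative, not taken from my notebook.

import torch

def ppo_clip_loss(new_log_prob, old_log_prob, advantage, clip_eps=0.2):
    # ratio of the new policy's probability to the old one for the taken actions
    ratio = torch.exp(new_log_prob - old_log_prob)
    # unclipped vs. clipped surrogate objectives
    surr1 = ratio * advantage
    surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantage
    # PPO maximizes the pessimistic (min) surrogate, so the loss is its negative mean
    return -torch.min(surr1, surr2).mean()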
Points of Interest
PPO has many components that need to be implemented correctly; here I'd like to cover the ones that were most impactful for me.
Log probability instead of Argmax:
One of the first mistakes I made was taking the action with the maximum probability after the softmax layer. This turned out to be flawed because it limited exploration. Sampling a random action according to those probabilities proved very useful during training.
dist = torch.distributions.Categorical(probs=prob)  # prob: softmax output over actions
action = dist.sample()                              # sample instead of argmax to keep exploring
log_prob = dist.log_prob(action)                    # log-probability of the sampled action
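For context, prob here is the softmax output of the policy network for the current state (Categorical also accepts logits=... directly, which skips the explicit softmax). The stored log_prob is what later feeds the probability ratio in the clipped loss above; greedy argmax is still fine at evaluation time, just not while collecting training rollouts.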
Generalized Advantage Estimation (GAE) instead of TD Advantage
Probably the most important mistake in my implementation was using the one-step Temporal Difference Advantage to calculate the advantage; this caused high bias and messed with the loss.
gae = 0.0  # running GAE accumulator
for t in reversed(range(T)):
    # one-step TD error at timestep t
    delta = reward_col[t] + self.gamma * value_col[t + 1] * (1 - dones[t]) - value_col[t]
    # decay the accumulated error by gamma * lambda, resetting at episode boundaries
    gae = delta + self.gamma * self.lamda * gae * (1 - dones[t])
    advantage[t] = gae
    return_col[t] = gae + value_col[t]  # return target for the critic
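One small step that most PPO implementations add after this loop (included here as a common practice, not something from my snippet above) is normalizing the advantages per batch before the policy update:

advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)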
GAE applies the TD error recursively over multiple timesteps, which smooths the advantage estimate and lets the lambda parameter trade off bias against variance.
My Final Implementation
You can check out my final implementation at https://www.kaggle.com/code/ankitupadhyay12/ppo-cart. It's definitely not perfect, but it served as a good jumping-off point. Feel free to point out any mistakes I might have made 🫠😅
Result