Policy Gradient and PPO
From the policy gradient theorem to PPO's clipped objective, with interactive equations and diagnostics.
Big picture: Policy gradients push the policy in the direction of higher expected returns, but raw updates can be unstable. PPO wraps the same idea with a trust-region style guardrail that keeps updates within a safe window.
Why Policy Gradients?
Direct optimization: parameterize \(\pi_\theta(a|s)\) and ascend the gradient of expected return, with no detour through explicit action-value tables.

Where policy gradients shine:
- Continuous action spaces (robots, trading) — a minimal policy sketch follows this list
- Stochastic policies with exploration built in
- Settings where you care about the full action distribution, not just the argmax decision

Where they struggle:
- High-variance gradients can stall learning
- A single large step can wipe out a good policy
- Sample efficiency hinges on solid baselines and critics
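To make the continuous-action case concrete, here is a minimal sketch of a parameterized stochastic policy as a diagonal Gaussian in PyTorch. The class name, network sizes, and state-independent log-std are illustrative assumptions, not a prescription.

import torch
import torch.nn as nn

class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy for continuous actions (illustrative sketch)."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mu_net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(),
            nn.Linear(64, act_dim),
        )
        # State-independent log-std is a common simple choice.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs):
        mu = self.mu_net(obs)
        std = self.log_std.exp()
        return torch.distributions.Normal(mu, std)

# Sampling an action and its log-probability:
# dist = policy(obs); a = dist.sample(); logp = dist.log_prob(a).sum(-1)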
From score function to updates
The policy gradient theorem gives a practical recipe:
- Sample trajectories \((s_t, a_t, r_t)\)
- Estimate \(A_t\) (advantage or centered return)
- Ascend the gradient: \(\theta \leftarrow \theta + \alpha \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t)\, A_t\)
Variance reduction through baselines is what makes this workable in practice.
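In code, one such update step looks roughly like the sketch below. The function name, batching convention, and tensor shapes are assumptions for illustration; any policy module that returns a torch distribution would fit.

import torch

def policy_gradient_step(policy, optimizer, obs, actions, advantages):
    """One gradient-ascent step on the score-function surrogate (sketch).

    obs: (T, obs_dim), actions: (T,), advantages: (T,) tensors from a rollout.
    """
    dist = policy(obs)                 # batched action distribution
    logp = dist.log_prob(actions)      # log pi_theta(a_t | s_t)
    # Maximizing sum_t logp_t * A_t is the same as minimizing its negative.
    loss = -(logp * advantages).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()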
Policy Gradient Theorem
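For reference, the theorem in its standard on-policy form, with the expectation taken over states visited by \(\pi_\theta\) and the actions it samples:

\[
\nabla_\theta J(\theta) = \mathbb{E}_{s \sim d^{\pi_\theta},\, a \sim \pi_\theta}\!\left[\nabla_\theta \log \pi_\theta(a \mid s)\, Q^{\pi_\theta}(s, a)\right]
\]

Replacing \(Q^{\pi_\theta}(s,a)\) with the advantage \(A^{\pi_\theta}(s,a)\) leaves the expectation unchanged, which is exactly the baseline trick discussed next.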
Variance countermeasures
Subtract baselines or estimate advantages so the gradient stays well-behaved. Experiment below to see how variance changes with the strength of a learned baseline.
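As a stand-in for the interactive widget, here is a toy Monte Carlo check on a two-armed bandit with a softmax policy: the same score-function estimator, with and without an expected-reward baseline. All numbers and names are illustrative.

import torch

torch.manual_seed(0)

# Two-armed bandit with a fixed softmax policy over logits theta.
theta = torch.tensor([0.2, -0.1])
rewards = torch.tensor([1.0, 0.0])               # arm 0 pays 1, arm 1 pays 0
probs = torch.softmax(theta, dim=0)

def grad_samples(baseline, n=10_000):
    """Score-function gradient samples, with an optional baseline subtracted."""
    actions = torch.multinomial(probs, n, replacement=True)
    r = rewards[actions]
    # Softmax score function: d/d theta_k log pi(a) = 1[a == k] - pi_k
    score = torch.nn.functional.one_hot(actions, 2).float() - probs
    return score * (r - baseline).unsqueeze(1)

for b in (0.0, (rewards * probs).sum().item()):  # no baseline vs. E[r] baseline
    g = grad_samples(b)
    print(f"baseline={b:.2f}  mean={g.mean(0).tolist()}  "
          f"total variance={g.var(0).sum().item():.4f}")

The means match (the estimator stays unbiased) while the variance drops once the baseline is subtracted.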
REINFORCE (Monte Carlo Policy Gradient)
Classic REINFORCE uses complete returns \(G_t\) as the learning signal. It is unbiased but notoriously noisy.
Subtracting a baseline (typically a value estimate) leaves the gradient unbiased while shrinking variance.
Use the controls to see how the return trace changes with the discount factor. Randomize trajectories to observe how variance fluctuates.
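A sketch of the baseline-subtracted learning signal, assuming a separately trained value estimate; the function and argument names are illustrative. The full-return skeleton later in this post omits the baseline, so this is the piece you would add.

import torch

def reinforce_targets(returns, values):
    """Baseline-subtracted REINFORCE targets (sketch).

    returns: discounted returns G_t; values: baseline estimates V(s_t).
    Subtracting an action-independent baseline leaves the gradient unbiased.
    """
    advantages = returns - values.detach()   # do not backprop through the baseline here
    # Optional normalization for numerical stability, as in the skeleton below.
    return (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# The baseline itself is usually fit by regression on the returns:
# value_loss = torch.nn.functional.mse_loss(values, returns)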
Proximal Policy Optimization (PPO)
Trust region intuition: the probability ratio \(r_t(\theta) = \pi_\theta(a_t|s_t) / \pi_{\theta_\text{old}}(a_t|s_t)\) measures how far the updated policy has drifted from the one that collected the data. PPO's clipped objective keeps \(r_t\) within \([1-\epsilon, 1+\epsilon]\); outside this window the objective plateaus, discouraging oversized steps.
Try sweeping \(\epsilon\). Smaller windows produce conservative updates; larger windows behave more like vanilla policy gradients and risk instability.
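For concreteness, the clipped surrogate is

\[
L^{\text{CLIP}}(\theta) = \mathbb{E}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\right]
\]

and a minimal PyTorch sketch of the corresponding loss follows. Tensor names such as `logp_old` and `advantages` are illustrative; any policy module returning a torch distribution would slot in.

import torch

def ppo_clip_loss(policy, obs, actions, logp_old, advantages, eps=0.2):
    """Negative clipped surrogate (to be minimized). Illustrative sketch."""
    dist = policy(obs)
    logp = dist.log_prob(actions)
    ratio = torch.exp(logp - logp_old)                       # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    # min() makes the objective pessimistic: clipping only ever removes incentive.
    return -torch.min(unclipped, clipped).mean()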
Putting It Together
Checklist before training
- Normalize advantages, and add or strengthen an entropy bonus if exploration collapses.
- Monitor the KL divergence between old and new policies; anneal \(\epsilon\) or early-stop the update epochs if it spikes (see the sketch after this checklist).
- Use multiple epochs of minibatch updates on the same rollout but keep the “old” policy fixed.
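One common way to monitor update size is an approximate KL computed from stored log-probabilities. The threshold and `target_kl` value below are typical heuristics, not fixed rules.

import torch

def approx_kl(logp_old, logp_new):
    """Sample-based estimate of KL(pi_old || pi_new) from stored log-probs."""
    return (logp_old - logp_new).mean()

# Inside the PPO epoch loop (sketch):
# kl = approx_kl(logp_old, dist.log_prob(actions))
# if kl > 1.5 * target_kl:      # target_kl is a hyperparameter, e.g. 0.01
#     break                     # stop further epochs on this rollout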
REINFORCE in PyTorch (Skeleton)
import torch
import torch.nn as nn
import torch.optim as optim


class Policy(nn.Module):
    """Categorical policy over discrete actions."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, act_dim),
        )

    def forward(self, x):
        logits = self.net(x)
        return torch.distributions.Categorical(logits=logits)


def compute_returns(rewards, gamma):
    """Discounted returns G_t, computed by scanning the episode backwards."""
    G, ret = [], 0.0
    for r in reversed(rewards):
        ret = r + gamma * ret
        G.append(ret)
    return list(reversed(G))


def train(env, policy, optimizer, gamma=0.99, episodes=1000):
    for ep in range(episodes):
        logps, rewards = [], []
        obs, _ = env.reset()
        done = False
        while not done:
            obs_t = torch.as_tensor(obs, dtype=torch.float32)
            dist = policy(obs_t)
            action = dist.sample()
            logps.append(dist.log_prob(action))
            obs, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            done = terminated or truncated

        returns = torch.as_tensor(compute_returns(rewards, gamma), dtype=torch.float32)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)  # optional norm
        # REINFORCE loss: negative of sum_t log pi(a_t|s_t) * G_t.
        loss = -(torch.stack(logps) * returns).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
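For completeness, one way to wire the skeleton to a Gymnasium-style environment (the skeleton already assumes the Gymnasium reset/step API). The environment name, learning rate, and episode count below are illustrative choices.

import gymnasium as gym
import torch.optim as optim

env = gym.make("CartPole-v1")
policy = Policy(env.observation_space.shape[0], env.action_space.n)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)
train(env, policy, optimizer, gamma=0.99, episodes=500)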
Further Reading
- Hugging Face Deep RL Course (Unit 4) — approachable walkthrough of policy gradients and PPO.
- Schulman et al. 2017 — original PPO paper with clipped objective and adaptive KL variants.
- OpenAI Spinning Up — detailed notes plus reference implementations.