Policy Gradient & PPO

From the policy gradient theorem to PPO's clipped objective, with equations and a REINFORCE skeleton in PyTorch.

Big picture: Policy gradients push the policy in the direction of higher expected returns, but raw updates can be unstable. PPO wraps the same idea in a trust-region-style guardrail that keeps each update close to the policy that collected the data.

Why Policy Gradients?

Direct optimization: parameterize \(\pi_\theta(a|s)\) and climb the gradient of expected return, with no detour through learning action values and acting greedily on them.

Where they shine
  • Continuous action spaces (robots, trading)
  • Stochastic policies where exploration is built-in
  • When you care about action distributions, not just argmax decisions
Pain points
  • High-variance gradients can stall learning
  • Large steps can wipe out a good policy
  • Sample efficiency hinges on solid baselines and critics

From score function to updates

The policy gradient theorem gives a practical recipe:

  1. Sample trajectories \((s_t, a_t, r_t)\)
  2. Estimate \(A_t\) (advantage or centered return)
  3. Update \(\theta\) by gradient ascent along \(\sum_t \nabla_\theta \log \pi_\theta(a_t|s_t)\, A_t\)

Variance reduction through baselines is what makes this workable in practice.
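The three-step recipe can be checked numerically on a one-step problem. This is a minimal sketch under stated assumptions. It uses a hypothetical three-armed bandit with a softmax policy over logits and a fixed reward per arm. `score_function_gradient`, `exact_gradient`, and `grad_log_pi` are illustrative names.

For a softmax, \(\partial \log \pi(a)/\partial\,\text{logit}_k = \mathbb{1}[k=a] - \pi_k\). That makes it possible to compare the sampled estimate against the closed-form gradient.

```python
import math
import random

def softmax(logits):
    m = max(logits)                                   # subtract max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def grad_log_pi(logits, action):
    # d log pi(action) / d logit_k = 1[k == action] - pi_k  (softmax policy)
    probs = softmax(logits)
    return [(1.0 if k == action else 0.0) - p for k, p in enumerate(probs)]

def score_function_gradient(logits, arm_rewards, baseline=0.0, n_samples=20000, seed=0):
    """Monte Carlo policy-gradient estimate for a hypothetical one-step (bandit) problem."""
    rng = random.Random(seed)
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for _ in range(n_samples):
        a = rng.choices(range(len(logits)), weights=probs)[0]   # 1. sample an action
        adv = arm_rewards[a] - baseline                          # 2. centered learning signal
        for k, g in enumerate(grad_log_pi(logits, a)):           # 3. accumulate grad log pi * A
            grad[k] += g * adv / n_samples
    return grad

def exact_gradient(logits, arm_rewards):
    # J = sum_a pi_a R_a, so dJ/dlogit_k = pi_k * (R_k - J)
    probs = softmax(logits)
    J = sum(p * r for p, r in zip(probs, arm_rewards))
    return [p * (r - J) for p, r in zip(probs, arm_rewards)]
```

Consider logits `[0, 0, 0]` and rewards `[1, 0, -1]`. Both functions return roughly `[1/3, 0, -1/3]`. Adding the estimate to the logits (one gradient-ascent step) raises the probability of the best arm.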

Policy Gradient Theorem

Objective
Maximize \(J(\theta) = \mathbb{E}_{\tau\sim\pi_\theta}\big[\sum_t \gamma^t r_t\big]\). The gradient pushes probability mass toward actions with positive advantage.
Role of \(A_t\)
Advantages center the gradient. When \(A_t > 0\) we increase \(\pi_\theta(a_t|s_t)\); when \(A_t < 0\) we decrease it.
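For reference, here is the standard form of the theorem for the discounted objective above, written with the reward-to-go \(G_t\):

```latex
\nabla_\theta J(\theta)
  = \mathbb{E}_{\tau\sim\pi_\theta}\Big[\sum_t \gamma^t\, \nabla_\theta \log \pi_\theta(a_t|s_t)\, G_t\Big],
\qquad
G_t = \sum_{k \ge 0} \gamma^k r_{t+k}
```

\(G_t\) can be replaced by \(G_t - b(s_t)\) or by the advantage \(A^{\pi}(s_t, a_t)\) without changing the expectation. Most implementations also drop the \(\gamma^t\) weight. The update is then no longer the exact gradient of the discounted objective, but this is common practice.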

Variance countermeasures

Subtract a baseline or estimate advantages so the gradient stays well-behaved. The closer a learned baseline tracks the expected return, the more variance it typically removes.

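The effect of baseline strength can be computed exactly on a toy problem. This sketch makes three assumptions:

- a softmax policy over three arms with hypothetical rewards `[11, 10, 9]`;
- a baseline \(b = \beta J\), where \(J\) is the true expected reward, standing in for a well-trained value estimate;
- a hypothetical mixing weight \(\beta\), where \(\beta = 0\) means no baseline.

Enumerating the actions, instead of sampling them, gives the exact mean and total variance of the single-sample estimator \(\nabla \log \pi(a)\,(R(a) - b)\).

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def estimator_stats(logits, arm_rewards, baseline):
    """Exact mean and total variance of grad log pi(a) * (R(a) - baseline), with a ~ pi."""
    probs = softmax(logits)
    n = len(probs)
    mean, second = [0.0] * n, [0.0] * n
    for a, p_a in enumerate(probs):                  # enumerate actions instead of sampling
        for k in range(n):
            g = ((1.0 if k == a else 0.0) - probs[k]) * (arm_rewards[a] - baseline)
            mean[k] += p_a * g
            second[k] += p_a * g * g
    total_var = sum(s - m * m for s, m in zip(second, mean))
    return mean, total_var

logits, arm_rewards = [0.0, 0.0, 0.0], [11.0, 10.0, 9.0]
J = sum(p * r for p, r in zip(softmax(logits), arm_rewards))   # expected reward, about 10
for beta in (0.0, 0.5, 1.0):
    mean, var = estimator_stats(logits, arm_rewards, baseline=beta * J)
    print(f"beta={beta:.1f}  total variance={var:.3f}")
```

The mean stays at \([1/3, 0, -1/3]\) for every \(\beta\), so the estimator remains unbiased. The total variance falls from about 66.9 (\(\beta = 0\)) to 16.9 (\(\beta = 0.5\)) and 0.22 (\(\beta = 1\)).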

REINFORCE (Monte Carlo Policy Gradient)

Classic REINFORCE uses complete returns \(G_t\) as the learning signal. It is unbiased but notoriously noisy.

Subtracting a baseline that does not depend on the action (typically a state-value estimate \(V(s_t)\)) leaves the gradient unbiased while usually shrinking its variance.
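The unbiasedness claim follows in one line, because the baseline \(b(s)\) does not depend on the action:

```latex
\mathbb{E}_{a\sim\pi_\theta(\cdot|s)}\big[\nabla_\theta \log \pi_\theta(a|s)\, b(s)\big]
  = b(s) \sum_a \pi_\theta(a|s)\, \frac{\nabla_\theta \pi_\theta(a|s)}{\pi_\theta(a|s)}
  = b(s)\, \nabla_\theta \sum_a \pi_\theta(a|s)
  = b(s)\, \nabla_\theta 1
  = 0
```

The argument breaks if \(b\) depends on the action, because \(b(s)\) can then no longer be pulled out of the sum. This is why baselines are functions of the state only.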

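The discount factor \(\gamma\) sets how far ahead the return \(G_t = \sum_{k\ge 0}\gamma^k r_{t+k}\) looks. Values near 1 credit distant rewards but sum more noisy terms, which typically raises the variance of \(G_t\). Smaller values emphasize immediate rewards. Because every trajectory is sampled, \(G_t\) also varies from one rollout to the next.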

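Both effects show up in a small experiment. This sketch uses two hypothetical toy settings:

- a sparse reward of 1 at the last of four steps;
- random trajectories whose 200 rewards are i.i.d. standard normal.

`compute_returns` mirrors the helper in the REINFORCE skeleton. `return_spread` is an illustrative name.

```python
import random
import statistics

def compute_returns(rewards, gamma):
    # Backward recursion G_t = r_t + gamma * G_{t+1} (same as the REINFORCE skeleton)
    G, ret = [], 0.0
    for r in reversed(rewards):
        ret = r + gamma * ret
        G.append(ret)
    return list(reversed(G))

def return_spread(gamma, horizon=200, n_traj=1000, seed=0):
    """Sample std of G_0 over toy trajectories with i.i.d. N(0, 1) rewards (an assumption)."""
    rng = random.Random(seed)
    g0 = [compute_returns([rng.gauss(0.0, 1.0) for _ in range(horizon)], gamma)[0]
          for _ in range(n_traj)]
    return statistics.stdev(g0)

sparse = [0.0, 0.0, 0.0, 1.0]
print(compute_returns(sparse, 0.99))   # credit reaches back almost undiminished
print(compute_returns(sparse, 0.5))    # [0.125, 0.25, 0.5, 1.0]
print(return_spread(0.9), return_spread(0.99))
```

For i.i.d. unit-variance rewards, the standard deviation of \(G_0\) is \(\sqrt{\sum_k \gamma^{2k}}\). Over 200 steps that comes to about 2.3 for \(\gamma = 0.9\) and about 7.0 for \(\gamma = 0.99\), and the sampled spreads land close to those values.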

Proximal Policy Optimization (PPO)

Probability ratio
\(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\) compares the new policy to the frozen snapshot used for sampling.

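Trust-region intuition: rather than enforcing a hard constraint, PPO maximizes the clipped surrogate \(L^{\text{CLIP}}(\theta) = \mathbb{E}_t\big[\min\big(r_t(\theta) A_t,\ \operatorname{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\, A_t\big)\big]\). The objective plateaus once \(r_t\) leaves \([1-\epsilon, 1+\epsilon]\) in the direction the advantage favors, that is, above \(1+\epsilon\) when \(A_t > 0\) or below \(1-\epsilon\) when \(A_t < 0\). Its gradient then vanishes, which discourages oversized steps.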
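The clipped objective takes only a few lines per sample. This minimal sketch is in plain Python and assumes log-probabilities are already available for the new and old policies. `clipped_surrogate` and `ppo_clip_loss` are illustrative names.

The sketch only evaluates the objective. In PyTorch the same expression is typically built from `torch.exp`, `torch.clamp`, and `torch.min`, so that autograd supplies the gradient.

```python
import math

def clipped_surrogate(logp_new, logp_old, advantage, eps=0.2):
    """Per-sample PPO-Clip objective, to be maximized."""
    ratio = math.exp(logp_new - logp_old)              # r_t computed in log space
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
    # Pessimistic choice: keep the smaller of the unclipped and clipped terms
    return min(ratio * advantage, clipped_ratio * advantage)

def ppo_clip_loss(logps_new, logps_old, advantages, eps=0.2):
    # Batch mean, negated so that a minimizer performs gradient ascent on L^CLIP
    terms = [clipped_surrogate(n, o, a, eps) for n, o, a in zip(logps_new, logps_old, advantages)]
    return -sum(terms) / len(terms)

for ratio in (0.5, 1.0, 1.5):
    for adv in (1.0, -1.0):
        print(ratio, adv, clipped_surrogate(math.log(ratio), 0.0, adv))
```

With \(\epsilon = 0.2\), a ratio of 1.5 with \(A_t = 1\) is credited only 1.2. The same ratio with \(A_t = -1\) keeps its full penalty of \(-1.5\). The `min` removes incentive but never removes penalty.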

Figure: surrogate objective versus the ratio \(r_t\) for \(A_t > 0\) with \(\epsilon = 0.20\). The unclipped objective \(r_t A_t\) keeps rising, while the clipped objective goes flat once \(r_t > 1+\epsilon\).

Sweeping \(\epsilon\) shows the trade-off. Smaller windows produce more conservative updates, while larger windows behave more like the unclipped policy gradient and risk instability.
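The trade-off can be reproduced with repeated gradient-ascent steps on a single sampled action, as happens over several PPO epochs on one rollout. This sketch rests on these assumptions:

- a hypothetical two-action softmax policy with logits \([\theta, 0]\);
- a frozen snapshot at \(\theta = 0\), so \(\pi_{\text{old}} = 0.5\);
- an advantage of \(A_t = 1\);
- hand-derived gradients.

`pi0`, `surrogate_grad`, and `final_ratio` are illustrative names.

```python
import math

def pi0(theta):
    # Two-action softmax with logits [theta, 0]: pi(a=0) = sigmoid(theta)
    return 1.0 / (1.0 + math.exp(-theta))

def surrogate_grad(theta, p_old, advantage, eps):
    """d/dtheta of min(r*A, clip(r, 1-eps, 1+eps)*A) for action 0; eps=None disables clipping."""
    p = pi0(theta)
    ratio = p / p_old
    if eps is not None:
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        if clipped * advantage < ratio * advantage:
            return 0.0                          # clipped branch is active and constant
    return advantage * p * (1.0 - p) / p_old    # d ratio / d theta = sigmoid'(theta) / p_old

def final_ratio(eps, steps=200, lr=0.1):
    theta = 0.0
    p_old = pi0(theta)                          # frozen "old" policy: 0.5
    for _ in range(steps):
        theta += lr * surrogate_grad(theta, p_old, 1.0, eps)
    return pi0(theta) / p_old

for eps in (0.1, 0.2, 0.3, None):
    print(eps, round(final_ratio(eps), 3))
```

Each clipped run stops just past \(1+\epsilon\). The last step overshoots slightly because clipping removes the incentive rather than enforcing a hard bound. The unclipped run keeps pushing the ratio toward its ceiling of 2.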

Putting It Together

Advantage \(A_t\)
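Measures how much better an action is than the baseline, for example \(A_t \approx G_t - V(s_t)\). Generalized advantage estimation (GAE) blends multi-step TD errors through a parameter \(\lambda\), trading a small amount of bias for lower variance.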
Actor-Critic Loop
The actor updates the policy with the clipped surrogate objective (PPO clips the probability ratio, not the advantages). The critic learns \(V(s)\) so that advantages stay centered.
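Here is a compact GAE sketch for a single rollout. It assumes the critic's values are already computed. `compute_gae`, `next_value`, and the `dones` flags are illustrative names. Treating time-limit truncations as terminal is a simplification this sketch makes.

It implements the TD error \(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) and the recursion \(A_t = \delta_t + \gamma\lambda A_{t+1}\).

```python
def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    """GAE over one rollout.
    values[t] = V(s_t); next_value = V(s_T) bootstraps the final step;
    dones[t] = True if the episode ended at step t (no bootstrapping across it)."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        v_next = next_value if t == len(rewards) - 1 else values[t + 1]
        nonterminal = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * v_next * nonterminal - values[t]   # TD error
        gae = delta + gamma * lam * nonterminal * gae                    # exponentially weighted sum
        advantages[t] = gae
    returns = [a + v for a, v in zip(advantages, values)]               # critic regression targets
    return advantages, returns
```

The two ends of \(\lambda\) recover familiar estimators. Setting \(\lambda = 1\) gives Monte Carlo returns minus \(V(s_t)\), which has higher variance. Setting \(\lambda = 0\) gives the one-step TD error, which has more bias.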

Checklist before training

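  • Normalize advantages per batch, and add (or raise) an entropy bonus if exploration collapses.
  • Monitor the KL divergence between the old and current policies; anneal \(\epsilon\) or stop the update epochs early if it spikes.
  • Run multiple epochs of minibatch updates on the same rollout, but keep the “old” policy and its sampled log-probabilities fixed throughout.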
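The checklist's update structure can be sketched without tying it to a framework. The following are all hypothetical:

- `update_minibatch(indices)` performs one optimizer step on the clipped loss for those samples;
- `current_logps()` re-evaluates the batch under the current policy;
- the early-stop threshold of 1.5 × `target_kl` is one common heuristic, not a fixed rule.

The KL check uses the sample estimator \(\mathbb{E}[(r_t - 1) - \log r_t]\). It is non-negative and unbiased for \(\mathrm{KL}(\pi_{\text{old}} \,\|\, \pi_\theta)\) when the samples come from the old policy.

```python
import math
import random
import statistics

def normalize_advantages(advantages, eps=1e-8):
    # Per-batch standardization: zero mean, unit (population) std
    mu = statistics.fmean(advantages)
    sigma = statistics.pstdev(advantages)
    return [(a - mu) / (sigma + eps) for a in advantages]

def approx_kl(logps_old, logps_new):
    # Mean of (r - 1) - log r with log r = logp_new - logp_old; each term is >= 0
    terms = [math.expm1(n - o) - (n - o) for o, n in zip(logps_old, logps_new)]
    return sum(terms) / len(terms)

def ppo_update(n_samples, logps_old, update_minibatch, current_logps,
               epochs=10, minibatch_size=64, target_kl=0.015, seed=0):
    """Shuffled minibatch epochs over one rollout; logps_old stays frozen throughout."""
    rng = random.Random(seed)
    indices = list(range(n_samples))
    epochs_run, kl = 0, 0.0
    for epoch in range(epochs):
        rng.shuffle(indices)
        for start in range(0, n_samples, minibatch_size):
            update_minibatch(indices[start:start + minibatch_size])
        epochs_run = epoch + 1
        kl = approx_kl(logps_old, current_logps())
        if kl > 1.5 * target_kl:        # drifted too far from the snapshot: stop early
            break
    return epochs_run, kl

# Toy check: a stand-in (not a real policy) whose log-probs drift by 0.05 per minibatch step
state = {"drift": 0.0}
old = [math.log(0.5)] * 8

def fake_update(indices):
    state["drift"] += 0.05

def fake_logps():
    return [lp + state["drift"] for lp in old]

print(ppo_update(8, old, fake_update, fake_logps, minibatch_size=4))   # stops after 3 epochs
```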

REINFORCE in PyTorch (Skeleton)

import torch
import torch.nn as nn
import torch.optim as optim

class Policy(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, act_dim)
        )

    def forward(self, x):
        logits = self.net(x)
        return torch.distributions.Categorical(logits=logits)

def compute_returns(rewards, gamma):
    # Discounted reward-to-go, accumulated backwards: G_t = r_t + gamma * G_{t+1}
    G, ret = [], 0.0
    for r in reversed(rewards):
        ret = r + gamma * ret
        G.append(ret)
    return list(reversed(G))

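def train(env, policy, optimizer, gamma=0.99, episodes=1000):
    # Assumes the Gymnasium API: reset() -> (obs, info), step() -> (obs, reward, terminated, truncated, info)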
    for ep in range(episodes):
        logps, rewards = [], []
        obs, _ = env.reset()
        done = False
        while not done:
            obs_t = torch.as_tensor(obs, dtype=torch.float32)
            dist = policy(obs_t)
            action = dist.sample()
            logps.append(dist.log_prob(action))
            obs, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            done = terminated or truncated
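        returns = torch.as_tensor(compute_returns(rewards, gamma), dtype=torch.float32)
        if len(returns) > 1:
            # Optional: standardize returns; std() of a single element is NaN, so skip 1-step episodes
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        # Minimizing -sum_t log pi(a_t|s_t) * G_t performs gradient ascent on the REINFORCE objective
        loss = -(torch.stack(logps) * returns).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Example usage (assumes the gymnasium package and its CartPole-v1 environment):
#   import gymnasium as gym
#   env = gym.make("CartPole-v1")
#   policy = Policy(env.observation_space.shape[0], env.action_space.n)
#   train(env, policy, optim.Adam(policy.parameters(), lr=1e-3))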

Further Reading