从策略梯度定理到现代强化学习

2025-11-10

摘要:本讲义系统梳理了策略梯度(Policy Gradient, PG)方法的发展脉络。从经典的策略梯度定理出发,推导了以减少方差为核心的优势函数估计(GAE);进而介绍了工业界主流的 PPO 算法及其在大模型时代的变体(RPO, GRPO)。

符号说明 (Notation)

符号含义
S,A\mathcal{S}, \mathcal{A}状态空间,动作空间
ρ0(s)\rho_0(s)初始状态分布
P(s;s,a)P(s';s,a)当前动作引起的环境状态转移
πθ(a;s)\pi_\theta(a;s)策略决定的动作分布
τ\tau轨迹 (s0,a0,r0,s1,a1,,sT)(s_0, a_0, r_0, s_1, a_1, \dots, s_T)
GtG_ttt 时刻后的折扣回报 l=0γlrt+l\sum_{l=0}^{\infty} \gamma^l r_{t+l}
Vπ(s),Qπ(s,a)V^\pi(s), Q^\pi(s,a)策略 π\pi 下的状态价值函数与动作价值函数
Aπ(s,a)A^\pi(s,a)优势函数 Qπ(s,a)Vπ(s)Q^\pi(s,a) - V^\pi(s)
γ,λ\gamma, \lambda折扣因子 γ[0,1)\gamma \in [0,1),GAE 平滑因子 λ[0,1]\lambda \in [0,1]

1. 策略梯度定理:从测度视角到可实现算法

策略梯度定理是所有基于梯度的强化学习算法的基石。它解决了在未知环境动态 P(ss,a)P(s'|s,a) 下如何优化策略参数的核心难题。

1.1 问题设定与目标函数

考虑一个无限视界的马尔可夫决策过程(MDP)。我们的目标是最大化期望折扣回报:

J(θ)=Eτpθ(τ)[t=0γtr(st,at)]J(\theta) = \mathbb{E}_{\tau \sim p_\theta(\tau)} \left[ \sum_{t=0}^{\infty} \gamma^t r(s_t, a_t) \right]

其中,一条轨迹 τ=(s0,a0,s1,a1,)\tau = (s_0, a_0, s_1, a_1, \dots) 出现的概率密度由环境动态和策略共同决定:

pθ(τ)=ρ0(s0)t=0πθ(atst)P(st+1st,at)p_\theta(\tau) = \rho_0(s_0) \prod_{t=0}^{\infty} \pi_\theta(a_t|s_t) P(s_{t+1}|s_t, a_t)

1.2 对数导数技巧 (Score Function Estimator)

为了计算 J(θ)J(\theta) 关于 θ\theta 的梯度,我们遇到一个困难:梯度算子 θ\nabla_\theta 作用在期望符号 E\mathbb{E} 内部的概率分布 pθp_\theta 上。利用对数导数技巧:

θpθ(τ)=pθ(τ)θlogpθ(τ)\nabla_\theta p_\theta(\tau) = p_\theta(\tau) \nabla_\theta \log p_\theta(\tau)

从而可以将梯度转化为期望形式:

θEτpθ[f(τ)]=θpθ(τ)f(τ)dτ=Eτpθ[f(τ)θlogpθ(τ)]\nabla_\theta \mathbb{E}_{\tau \sim p_\theta}[f(\tau)] = \int \nabla_\theta p_\theta(\tau) f(\tau) d\tau = \mathbb{E}_{\tau \sim p_\theta} [f(\tau) \nabla_\theta \log p_\theta(\tau)]

1.3 定理推导与因果性证明

将目标函数代入上述技巧,我们有:

θJ(θ)=Eτpθ[(t=0γtrt)θlogpθ(τ)]\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim p_\theta} \left[ \left( \sum_{t=0}^{\infty} \gamma^t r_t \right) \nabla_\theta \log p_\theta(\tau) \right]

展开 logpθ(τ)\log p_\theta(\tau),注意到环境动态 PP 和初始分布 ρ0\rho_0θ\theta 无关,其梯度为零:

θlogpθ(τ)=θ(logρ0(s0)+t=0(logπθ(atst)+logP(st+1st,at)))=t=0θlogπθ(atst)\nabla_\theta \log p_\theta(\tau) = \nabla_\theta \left( \log \rho_0(s_0) + \sum_{t=0}^\infty \left( \log \pi_\theta(a_t|s_t) + \log P(s_{t+1}|s_t, a_t) \right) \right) = \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t)

代回原式:

θJ(θ)=Eτ[(t=0γtrt)(t=0θlogπθ(atst))]\nabla_\theta J(\theta) = \mathbb{E}_{\tau} \left[ \left( \sum_{t'=0}^{\infty} \gamma^{t'} r_{t'} \right) \left( \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) \right) \right]

关键步骤(利用因果性): 在时刻 tt 做出的动作 ata_t 只会影响 tt 时刻及未来的奖励,而不影响过去的奖励。数学上,对于任何 t<tt' < t

Eτ[rtθlogπθ(atst)]=0\mathbb{E}_{\tau} [ r_{t'} \nabla_\theta \log \pi_\theta(a_t|s_t) ] = 0

因此,我们可以将求和项改为从 tt 开始的累积回报 GtG_t

θJ(θ)=Eτpθ[t=0θlogπθ(atst)(t=tγtrt)]=Eτ[t=0γtθlogπθ(atst)Gt]\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim p_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) \left( \sum_{t'=t}^\infty \gamma^{t'} r_{t'} \right) \right] = \mathbb{E}_{\tau} \left[ \sum_{t=0}^\infty \gamma^t \cdot \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot G_t \right]

(注:标准的PG定理通常会包含一个 γt\gamma^t 因子在梯度项前,但在实际实现中常被省略或归入 GtG_t 的定义中,此处采用严谨形式)

进一步减小方差,引入基线 (Baseline) b(st)b(s_t),只要 bb 不依赖于 ata_t,就不会引入偏差。最优的基线选择近似于 Vπ(st)V^\pi(s_t),从而导出了优势函数形式:

θJ(θ)=Eτ[t=0γtθlogπθ(atst)Aπθ(st,at)]\nabla_\theta J(\theta) = \mathbb{E}_{\tau} \left[ \sum_{t=0}^\infty \gamma^t \nabla_\theta \log \pi_\theta(a_t|s_t) A^{\pi_\theta}(s_t, a_t) \right]

1.4 状态访问测度视角

为了便于实际采样,我们将上述关于“轨迹”的期望转换为关于“状态-动作对”的期望。定义折扣状态访问测度 dπ(s)d^\pi(s)

dπ(s)=(1γ)t=0γtP(st=sπ)d^\pi(s) = (1-\gamma) \sum_{t=0}^\infty \gamma^t P(s_t = s | \pi)

则策略梯度可以重写为更紧凑的形式:

θJ(θ)=11γEsdπ,aπθ(s)[θlogπθ(as)Aπθ(s,a)]\nabla_\theta J(\theta) = \frac{1}{1-\gamma} \mathbb{E}_{s \sim d^\pi, a \sim \pi_\theta(\cdot|s)} \left[ \nabla_\theta \log \pi_\theta(a|s) A^{\pi_\theta}(s, a) \right]

这个公式不仅漂亮,而且直接指导了我们的算法设计:只需在环境中采样 (s,a)(s,a) 对,并估计其优势 A(s,a)A(s,a),即可进行梯度上升。

2. 优势估计:平衡偏差与方差的艺术

准确估计优势函数 Aπ(s,a)A^\pi(s,a) 是 PG 算法性能的关键。

2.1 价值函数与 Bellman 方程

定义状态价值 Vπ(s)=Eπ[Gtst=s]V^\pi(s) = \mathbb{E}_\pi [G_t | s_t=s],其满足 Bellman 方程:

Vπ(s)=Eaπ,sP[r(s,a)+γVπ(s)]V^\pi(s) = \mathbb{E}_{a\sim\pi, s'\sim P} [r(s,a) + \gamma V^\pi(s')]

优势函数定义为动作价值相对于平均状态价值的“盈余”:

Aπ(s,a)=Qπ(s,a)Vπ(s)=EsP[r(s,a)+γVπ(s)Vπ(s)]A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s) = \mathbb{E}_{s'\sim P} [r(s,a) + \gamma V^\pi(s') - V^\pi(s)]

2.2 TD 残差 (Temporal Difference Error)

在实际中,我们无法直接获得真实的期望。我们用样本来近似。定义 TD 残差 δt\delta_t 为:

δt=rt+γV(st+1)V(st)\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)

如果 VV 是真实价值函数 VπV^\pi,那么 δt\delta_t 就是优势函数 Aπ(st,at)A^\pi(s_t, a_t) 的无偏估计:

Est+1P[δtst,at]=Qπ(st,at)Vπ(st)=Aπ(st,at)\mathbb{E}_{s_{t+1}\sim P}[\delta_t | s_t, a_t] = Q^\pi(s_t, a_t) - V^\pi(s_t) = A^\pi(s_t, a_t)

2.3 广义优势估计 (GAE)

单步 TD 残差 δt\delta_t 虽然偏差小(如果 V 准确),但如果仅用一步回报,方差可能较大。蒙特卡洛回报 GtV(st)G_t - V(s_t) 方差大但无偏。 GAE 通过 λ\lambda 在二者之间进行权衡。它实际上是多步优势估计的指数加权平均:

A^tGAE(γ,λ)=l=0(γλ)lδt+l=δt+γλA^t+1GAE(γ,λ)\hat{A}_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l} = \delta_t + \gamma \lambda \hat{A}_{t+1}^{GAE(\gamma, \lambda)}

2.4 Critic 学习与伪代码

为了计算 δt\delta_t,我们需要一个 Critic 网络 Vψ(s)V_\psi(s) 来拟合 Vπ(s)V^\pi(s)。通常使用均方误差损失:

LV(ψ)=12Ni=1NVψ(st)Vtarget2L_V(\psi) = \frac{1}{2N} \sum_{i=1}^N \| V_\psi(s_t) - V_{target} \|^2

其中 VtargetV_{target} 通常选择 GtG_t 或者 A^t+Vold(st)\hat{A}_t + V_{old}(s_t)

def compute_gae_and_update_critic(batch, critic_net, optimizer_critic, gamma=0.99, lam=0.95):
    # batch 包含 (s, a, r, s_next, done) 的张量
    with torch.no_grad():
        values = critic_net(batch['s'])
        next_values = critic_net(batch['s_next'])
        # 计算 TD error
        deltas = batch['r'] + gamma * next_values * (1 - batch['done']) - values
        
        # 逆序递归计算 GAE
        advantages = torch.zeros_like(deltas)
        last_gae_lam = 0
        for t in reversed(range(len(deltas))):
            last_gae_lam = deltas[t] + gamma * lam * (1 - batch['done'][t]) * last_gae_lam
            advantages[t] = last_gae_lam
        
        # 计算价值目标 (V_target = A_GAE + V_old)
        returns = advantages + values

    # Critic 更新
    current_values = critic_net(batch['s'])
    critic_loss = 0.5 * ((current_values - returns)**2).mean()
    
    optimizer_critic.zero_grad()
    critic_loss.backward()
    optimizer_critic.step()
    
    return advantages, returns # 返回给 Actor 使用

3. 从 PPO 到 RPO:信赖域与正则化

标准的 PG 方法因为步长难以控制,容易导致策略更新过大而性能崩溃。

3.1 重要性采样与信赖域动机

为了复用旧策略 πold\pi_{old} 采集的数据来更新新策略 πθ\pi_\theta,我们引入重要性采样比率: rt(θ)=πθ(atst)πold(atst)r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{old}(a_t|s_t)}πθ\pi_\thetaπold\pi_{old} 相差不大时,这个比率接近 1。信赖域类方法(如 TRPO)显式约束 DKL(πoldπθ)δD_{KL}(\pi_{old} \| \pi_\theta) \le \delta

3.2 PPO (Proximal Policy Optimization)

PPO 通过裁剪 (Clipping) 简化了 TRPO 的复杂约束,直接在目标函数中惩罚过大的策略更新: LCLIP(θ)=Et[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]L^{CLIP}(\theta) = \mathbb{E}_t \left[ \min(r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}*t) \right] 这个目标函数是一个下界,确保了策略在优化过程中不会偏离 πold\pi*{old} 太远。

3.3 RPO (Reference Policy Optimization)

在某些场景(尤其是微调大模型时),我们不仅希望每次更新步幅小(PPO 的功能),还希望最终策略不要偏离一个特定的基准策略 (Reference Policy) πref\pi_{ref} 太远(例如预训练好的基础模型)。 RPO 在目标函数中显式加入 KL 正则项: JRPO(θ)=Eπθ[rt]βEs[DKL(πθ(s)πref(s))]J_{RPO}(\theta) = \mathbb{E}*{\pi*\theta} \left[ \sum r_t \right] - \beta \mathbb{E}*{s} [ D*{KL}(\pi_\theta(\cdot|s) | \pi_{ref}(\cdot|s)) ] 这等价于修改了原 MDP 的奖励函数: r~(s,a)=r(s,a)βlogπθ(as)πref(as)\tilde{r}(s,a) = r(s,a) - \beta \log \frac{\pi_\theta(a|s)}{\pi_{ref}(a|s)} 注:PPO 有时也会用 KL 散度来触发早停(Early Stopping),但 RPO 的 KL 通常是作为一个常驻的正则项存在,且参考对象 πref\pi_{ref} 通常是固定的,而 PPO 的参考对象 πold\pi_{old} 是动态更新的。

3.4 PPO/RPO 统一训练范式伪代码

# 假设已收集一批数据 (s, a, r, ...),并计算好 advantages
def ppo_rpo_update(agent, batch, advantages, pi_ref=None, beta=0.1, use_ppo_clip=True):
    # 预计算旧策略的 log 概率 (用于 PPO ratio)
    with torch.no_grad():
        old_log_probs = agent.actor.log_prob(batch['s'], batch['a'])
        if pi_ref is not None:
             ref_log_probs = pi_ref.log_prob(batch['s'], batch['a'])

    for _ in range(K_epochs):
        # 当前策略的 log 概率和分布
        new_log_probs, dist = agent.actor.evaluate(batch['s'], batch['a'])
        
        # --- PPO 核心 ---
        ratio = torch.exp(new_log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        
        # --- RPO / KL Penalty 核心 ---
        if pi_ref is not None:
            # 计算当前策略与参考策略的 KL 散度 (这里用简化的基于样本的估计: log(pi) - log(pi_ref))
            # 更严谨的做法是直接计算两个分布的解析 KL
            # kl_penalty = dist.kl_divergence(ref_dist).mean() 
            # 或者在大模型中常用的 token-level 近似:
            kl_approx = new_log_probs - ref_log_probs 
            policy_loss += beta * kl_approx.mean()

        # 梯度更新
        optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), 1.0)
        optimizer.step()

4. 大模型时代的后训练:RLHF 与 GRPO

在 LLM 场景下,环境变成了“语言生成”,动作是 Token,轨迹是完整的回复序列。

4.1 建模视角的转换

4.2 GRPO (Group Relative Policy Optimization)

在大模型推理等复杂任务中,训练一个准确的 Critic(价值网络)非常困难。GRPO 提出了一种无需 Critic 的方法。 它利用了组内比较的思想:对于同一个提示词 xx,采样一组回答 {y1,y2,,yG}\{y_1, y_2, \dots, y_G\},并计算它们的奖励 {r1,r2,,rG}\{r_1, r_2, \dots, r_G\}。 此时,某个回答 yiy_i 的优势可以近似为它在组内的相对排名或标准化得分: A^(x,yi)=rimean(rj)std(rj)+ϵ\hat{A}(x, y_i) = \frac{r_i - \text{mean}({r_j})}{\text{std}({r_j}) + \epsilon} GRPO 的优势

  1. 免去 Critic:减少了显存占用和训练不稳定性。
  2. 自适应基线:组内均值天然充当了 Baseline,有效降低方差。

4.3 GRPO 最小化实现

def grpo_step(llm_actor, reward_model, prompt_batch, pi_ref, group_size=4, beta=0.1):
    # 1. 采样阶段 (Sampling)
    # 对每个 prompt 重复采样 group_size 次
    prompts_repeated = prompt_batch.repeat_interleave(group_size, dim=0)
    with torch.no_grad():
        responses, old_log_probs = llm_actor.generate(prompts_repeated, return_log_probs=True)
        ref_log_probs = pi_ref.get_log_probs(prompts_repeated, responses)
    
    # 2. 打分阶段 (Scoring)
    raw_rewards = reward_model(prompts_repeated, responses) # shape: [Batch * G]
    # 将奖励 reshape 为 [Batch, Group_Size] 进行组内标准化
    rewards_grouped = raw_rewards.view(-1, group_size)
    mean_rewards = rewards_grouped.mean(dim=1, keepdim=True)
    std_rewards = rewards_grouped.std(dim=1, keepdim=True)
    advantages = (rewards_grouped - mean_rewards) / (std_rewards + 1e-8)
    advantages = advantages.view(-1) # 展平回 [Batch * G]
    
    # 3. 优化阶段 (Optimization) - 通常结合 PPO Clip
    # 这里的实现简化为单次更新,实际中会有 Inner Epochs
    new_log_probs = llm_actor.get_log_probs(prompts_repeated, responses)
    ratio = torch.exp(new_log_probs - old_log_probs)
    
    # GRPO 优势直接用组内相对得分,并结合 KL 惩罚
    kl_penalty = beta * (new_log_probs - ref_log_probs)
    # 注意:有时 KL 惩罚直接加在 reward 里,这里显式写在 loss 中
    
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1-eps, 1+eps) * advantages
    loss = -torch.min(surr1, surr2).mean() + kl_penalty.mean()
    
    loss.backward()
    optimizer.step()

主题: 强化学习, 策略梯度, GRPO