Recap. [Part 1] framed the problem (trajectory reward is too coarse for multi-step agents) and SDAR's fix (a privileged teacher gives dense token-level guidance, filtered through a gate). [Part 2] put the four-model system on AWS and counted the GPU cost. This part is the payoff: the actual gate, in PyTorch.
Honest label up front: what follows is a reference implementation - faithful to the paper's mechanism, written to be read and reasoned about. It is not a benchmarked run. I have no convergence curves to sell you; I have the machinery. (Part 4 designs the verification that money would buy.)
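Strip away the framework scaffolding and SDAR's entire contribution is a weighting coefficient on a distillation loss. Three moves:
1. Measure, for each generated token, the gap between the teacher's and the student's log-probability of that token.
2. Squash the gap through a sigmoid into a gate in `[0, 1]`.
3. Weight that token's distillation loss by the gate.
Positive gap (teacher more confident than student → genuine endorsement) → weight near 1 → distill hard. Negative gap (teacher less confident → a rejection that might just be noise) → weight near 0 → soften, don't obey. That asymmetry is the whole idea.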
For each generated token `t`, the teacher (with privileged context) and the student each assign a probability to the token that was actually produced. Define the gap as the difference in their log-probabilities:
gap_t = log p_teacher(token_t | privileged_context) − log p_student(token_t)
Pass it through a sigmoid to get the gate:
gate_t = σ(gap_t / τ)
`τ` is a temperature that stops the sigmoid from snapping to a hard 0/1 (more on that in the traps). The distillation signal itself is a per-token KL that pulls the student toward the teacher's full distribution:
KL_t = Σ_v p_teacher(v) · ( log p_teacher(v) − log p_student(v) )
And the combined objective:
L_total = L_GRPO + λ · mean_t( gate_t · KL_t )
RL stays primary. The gated distillation is an auxiliary nudge whose strength per token is set by how much we trust the teacher on that token.
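To make the arithmetic concrete, here is a small pure-Python sketch of the gap, gate, and gated KL for a single token. The 3-token vocabulary and both teacher distributions are made-up numbers, and `gated_term` is a hypothetical helper for illustration, not something taken from the paper.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def forward_kl(teacher_p, student_p):
    """KL(teacher || student) over one vocabulary distribution."""
    return sum(
        t * (math.log(t) - math.log(s))
        for t, s in zip(teacher_p, student_p)
        if t > 0
    )


def gated_term(teacher_p, student_p, token, tau=1.0):
    """Return (gap, gate, gate * KL) for a single generated token."""
    gap = math.log(teacher_p[token]) - math.log(student_p[token])
    gate = sigmoid(gap / tau)
    return gap, gate, gate * forward_kl(teacher_p, student_p)


# Hypothetical 3-token vocabulary; the student sampled token 0 in both cases.
student = [0.5, 0.3, 0.2]
endorsing_teacher = [0.8, 0.1, 0.1]  # teacher likes token 0 more than the student
rejecting_teacher = [0.1, 0.6, 0.3]  # teacher likes token 0 less than the student

gap_pos, gate_pos, term_pos = gated_term(endorsing_teacher, student, token=0)
gap_neg, gate_neg, term_neg = gated_term(rejecting_teacher, student, token=0)

print(f"endorse: gap={gap_pos:+.3f} gate={gate_pos:.3f} weighted_kl={term_pos:.4f}")
print(f"reject:  gap={gap_neg:+.3f} gate={gate_neg:.3f} weighted_kl={term_neg:.4f}")
```

With τ = 1 the gate simplifies to p_teacher / (p_teacher + p_student) for the sampled token, so in this toy case the endorsed token gets a weight of about 0.62 and the rejected one about 0.17. Lowering τ pushes both toward the 0/1 extremes.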
Framework-agnostic PyTorch, written to drop into the actor-update step of a verl-agent/OpenRLHF loss function. `student_logits` come from the policy being trained; `teacher_logits` come from a frozen, privileged-context forward pass done under `no_grad`.
```python
import torch
import torch.nn.functional as F


def gated_distillation_loss(
    student_logits,    # [B, T, V] - requires grad (the policy)
    teacher_logits,    # [B, T, V] - from a no_grad privileged forward pass
    actions,           # [B, T]    - the token ids actually generated (int64)
    response_mask,     # [B, T]    - 1 on generated tokens, 0 elsewhere
    tau: float = 1.0,  # gate temperature
):
    student_logp = F.log_softmax(student_logits, dim=-1)  # [B, T, V]
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)  # [B, T, V]

    # Log-probability each model assigns to the token that was actually generated.
    s_tok = student_logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # [B, T]
    t_tok = teacher_logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # [B, T]

    gap = (t_tok - s_tok).detach()   # DETACH: this is a weight, not a loss
    gate = torch.sigmoid(gap / tau)  # [B, T], in (0, 1), already detached

    # Forward KL(teacher || student), summed over the vocabulary.
    teacher_p = teacher_logp.exp()
    kl_per_tok = (teacher_p * (teacher_logp - student_logp)).sum(-1)  # [B, T]

    # Average over generated tokens only; the cast lets bool/int masks through.
    mask = response_mask.to(kl_per_tok.dtype)
    weighted = gate * kl_per_tok * mask
    return weighted.sum() / mask.sum().clamp(min=1.0)
```
And where it joins the primary objective:
```python
rl_loss = grpo_policy_loss(...)  # your existing GRPO term + KL-to-reference
distill = gated_distillation_loss(student_logits, teacher_logits,
                                  actions, response_mask, tau=tau)
total_loss = rl_loss + lam * distill  # lam scheduled, see below
total_loss.backward()
```
That's it. The teacher forward pass and the `gather` are the only real additions to a working GRPO step.
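A quick smoke test can confirm the wiring before this goes anywhere near a real training loop. The sketch below uses random tensors as stand-ins for real model outputs and repeats a condensed copy of the loss so that it runs on its own; the toy sizes and the two-token prompt prefix are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F


def gated_distillation_loss(student_logits, teacher_logits, actions, response_mask, tau=1.0):
    # Condensed copy of the loss above so this snippet is self-contained.
    s_logp = F.log_softmax(student_logits, dim=-1)
    t_logp = F.log_softmax(teacher_logits, dim=-1)
    idx = actions.unsqueeze(-1)
    gap = (t_logp.gather(-1, idx) - s_logp.gather(-1, idx)).squeeze(-1).detach()
    gate = torch.sigmoid(gap / tau)
    kl = (t_logp.exp() * (t_logp - s_logp)).sum(-1)
    mask = response_mask.to(kl.dtype)
    return (gate * kl * mask).sum() / mask.sum().clamp(min=1.0)


torch.manual_seed(0)
B, T, V = 2, 5, 11  # toy sizes, chosen arbitrarily
student_logits = torch.randn(B, T, V, requires_grad=True)
with torch.no_grad():  # stand-in for the frozen privileged-context teacher pass
    teacher_logits = torch.randn(B, T, V)
actions = torch.randint(0, V, (B, T))
response_mask = torch.ones(B, T)
response_mask[:, :2] = 0  # pretend the first two positions are prompt tokens

loss = gated_distillation_loss(student_logits, teacher_logits, actions, response_mask)
loss.backward()

loss_is_scalar = loss.dim() == 0
loss_is_nonnegative = loss.item() >= 0.0  # gate > 0 and KL >= 0
teacher_has_no_grad = (not teacher_logits.requires_grad) and teacher_logits.grad is None
masked_grad_is_zero = bool((student_logits.grad[:, :2] == 0).all())
unmasked_grad_is_nonzero = bool(student_logits.grad[:, 2:].abs().sum() > 0)
```

All five flags should come out `True`: the loss is a non-negative scalar, nothing reaches the teacher, and masked positions contribute no gradient to the student.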
The code above is short. Getting it right is where the time goes.
1. Detach the gap, and run the teacher under `no_grad`.
If `gap` keeps its graph, gradients flow into the teacher branch (which should never update) and into the weighting itself, producing bizarre second-order behaviour. The fix is `gap.detach()` plus a `with torch.no_grad():` around the teacher forward pass. Forget either and you'll spend an evening confused.
2. Mind the KL direction.
Forward KL `KL(teacher‖student)` is mode-covering - the student tries to put mass everywhere the teacher does. Reverse KL `KL(student‖teacher)` is mode-seeking - the student collapses onto the teacher's peak. Distillation usually wants forward KL (the code above). Swapping them silently changes what your agent learns; it won't crash, it'll just quietly behave differently.
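The difference between the two directions is easy to see on a toy distribution. The sketch below assumes a made-up bimodal teacher over four tokens and two hypothetical students, one broad and one peaked on a single mode, and scores each student under both directions.

```python
import math


def kl(p, q):
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical distributions: a two-peaked teacher and two candidate students.
teacher = [0.49, 0.01, 0.01, 0.49]
students = {
    "broad": [0.25, 0.25, 0.25, 0.25],   # spreads mass over everything
    "peaked": [0.97, 0.01, 0.01, 0.01],  # commits to one of the teacher's modes
}

forward = {name: kl(teacher, s) for name, s in students.items()}  # KL(teacher || student)
reverse = {name: kl(s, teacher) for name, s in students.items()}  # KL(student || teacher)

for name in students:
    print(f"{name:>6}: forward={forward[name]:.3f} reverse={reverse[name]:.3f}")
```

In this toy case forward KL scores the broad student as the better fit, while reverse KL prefers the peaked one. Same teacher, same students, opposite ranking, which is why the direction is worth checking explicitly.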
3. Watch the gate saturate.
If gaps are large in magnitude, `σ` pins to 0 or 1 and your "soft" gate becomes a hard binary mask - you've thrown away the nuance that justified the sigmoid. The temperature `τ` is the fix: raise it to keep the gate responsive. Log the gate's distribution during training; if it's bimodal at the extremes, `τ` is too low.
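One cheap way to watch for saturation is to track the fraction of gates sitting near 0 or 1. The sketch below is pure Python on a made-up list of gaps; `saturation_fraction` and the `eps=0.05` cut-off are hypothetical choices, not values from the paper.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def saturation_fraction(gaps, tau, eps=0.05):
    """Fraction of gate values within eps of 0 or 1."""
    gates = [sigmoid(g / tau) for g in gaps]
    saturated = sum(1 for g in gates if g < eps or g > 1.0 - eps)
    return saturated / len(gates)


# Hypothetical per-token log-prob gaps collected from one batch.
gaps = [-6.0, -4.0, -2.5, -1.0, 0.0, 0.5, 2.0, 3.5, 5.0, 8.0]

fractions = {tau: saturation_fraction(gaps, tau) for tau in (0.5, 1.0, 2.0, 4.0)}
for tau, frac in fractions.items():
    print(f"tau={tau}: {frac:.0%} of gates saturated")
```

On these numbers, half the gates are saturated at τ = 1 and none at τ = 4. In a real run the same statistic could be computed from the `gate` tensor and logged alongside the loss.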
4. Soften negatives - don't zero them.
The reason this is `σ(gap)` and not `relu(gap)` or a hard threshold: a teacher rejection might come from bad skill retrieval, not a bad token (this was the whole motivation in Part 1). A sigmoid leaves a small non-zero weight on rejected tokens, so a noisy "no" can't fully erase a token that was actually fine. Zeroing them throws that hedge away.
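A side-by-side of the three weighting choices on a few made-up gaps shows what the hedge looks like in numbers. The function names here are hypothetical and exist only for this comparison.

```python
import math


def sigmoid_gate(gap: float, tau: float = 1.0) -> float:
    return 1.0 / (1.0 + math.exp(-gap / tau))


def relu_gate(gap: float) -> float:
    return max(0.0, gap)


def hard_gate(gap: float) -> float:
    return 1.0 if gap > 0 else 0.0


# Hypothetical gaps: two teacher rejections, two endorsements.
for gap in (-2.0, -0.5, 0.5, 2.0):
    print(
        f"gap={gap:+.1f}  sigmoid={sigmoid_gate(gap):.3f}  "
        f"relu={relu_gate(gap):.3f}  hard={hard_gate(gap):.1f}"
    )
```

Only the sigmoid keeps a non-zero weight on the negative gaps. `relu(gap)` is also unbounded above, so it would need its own clipping before it could serve as a weight in `[0, 1]`.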
One more, not a NaN but a stability killer: schedule λ. Start it low (or at zero) and warm it up. Let GRPO establish a competent policy first; ramp the distillation in afterward. Cranking `λ` from step zero hands control to the teacher's noisiest early signals - which is exactly the naive-GRPO+OPSD instability SDAR exists to avoid.
Mapping back to Part 2's system:
- (`p4d`/`p5`).
- `λ` schedule position. Spot gives a ~2-minute reclaim notice; you want a job that resumes mid-schedule, not one that restarts from `λ=0` and wastes the warm-up you already paid for.

## Optional: the near-free "it runs" experiment
If you want a single screenshot of the gate behaving - not convergence, just proof the plumbing is sound - you can do it on free compute:
- Colab/Kaggle free tier (one T4, 16 GB), Qwen2.5-0.5B + LoRA, a handful of ALFWorld episodes.
- Goal: the loss doesn't NaN, and the gate-value histogram shifts as the student learns. That's it.
- It will almost certainly not learn the task at 0.5B on a toy slice - and that's fine. You're validating the mechanism, not the result. Label any plot from this as toy-scale, or a commenter rightly will.
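The λ warm-up from the traps section applies even to a toy run, and the resume-mid-schedule concern is cheap to test at this scale. This series doesn't pin down a schedule shape, so the delayed linear ramp below is an assumption, and `LambdaSchedule`, `lam_max`, `delay_steps`, and `warmup_steps` are hypothetical names and values.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class LambdaSchedule:
    """Hold lambda at 0, then ramp linearly to lam_max. All defaults are hypothetical."""

    lam_max: float = 0.1
    delay_steps: int = 200   # pure-GRPO steps before any distillation
    warmup_steps: int = 800  # steps to ramp from 0 up to lam_max
    step: int = 0

    def value(self) -> float:
        if self.step < self.delay_steps:
            return 0.0
        progress = (self.step - self.delay_steps) / max(1, self.warmup_steps)
        return self.lam_max * min(1.0, progress)

    def advance(self) -> None:
        self.step += 1


schedule = LambdaSchedule()
for _ in range(600):
    schedule.advance()

# Persist the schedule next to the model checkpoint (here just a JSON string).
saved = json.dumps(asdict(schedule))

# After an interruption, rebuild from the saved state instead of from defaults.
resumed = LambdaSchedule(**json.loads(saved))
fresh = LambdaSchedule()

print(f"resumed: step={resumed.step} lam={resumed.value():.4f}")
print(f"fresh:   step={fresh.step} lam={fresh.value():.4f}")
```

The resumed schedule picks up halfway through the ramp, while the fresh one is back at zero. The point of the sketch is only that the step counter has to be saved with the checkpoint; the ramp shape itself is a free choice.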
We have the mechanism. What we don't have - by design and by budget - is proof it beats the baselines and proof it's more stable than naive GRPO+OPSD, which is SDAR's real selling point. Part 4 designs exactly that: the three-way comparison, the stability instrumentation most reproductions skip, and the FinOps reality of running it for real.
Next: "Evaluation, Stability & FinOps" - how you'd prove the gate earns its keep, and what proving it costs.
If you've implemented gated or weighted distillation losses, I'd genuinely like to know how you handled the detach boundary and the KL direction - comments are open.