# The Whole Paper Fits in One Sigmoid: Implementing the SDAR Gate

> Source: <https://dev.to/shoaibalimir/the-whole-paper-fits-in-one-sigmoid-implementing-the-sdar-gate-1f4k>
> Published: 2026-06-14 07:18:57+00:00

Recap. Part 1 framed the problem (trajectory reward is too coarse for multi-step agents) and SDAR's fix (a privileged teacher gives dense token-level guidance, filtered through a gate). [Part 2](https://dev.to/shoaibalimir/four-models-in-one-training-loop-architecting-sdar-on-aws-before-renting-a-single-gpu-46ig) put the four-model system on AWS and counted the GPU cost. This part is the payoff: the actual gate, in PyTorch.

Honest label up front: what follows is a *reference implementation*, faithful to the paper's mechanism and written to be read and reasoned about. It is *not* a benchmarked run. I have no convergence curves to sell you; I have the machinery. (Part 4 designs the verification that money would buy.)

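Strip away the framework scaffolding and SDAR's entire contribution is a weighting coefficient on a distillation loss. Three moves:

1. Measure, per token, the gap between how confident the teacher and the student are in the token that was actually generated.
2. Squash that gap through a sigmoid into `[0, 1]`.
3. Use the result to weight the per-token distillation loss.

Positive gap (teacher more confident than student → genuine endorsement) → weight near 1 → distill hard. Negative gap (teacher *less* confident → a rejection that might just be noise) → weight near 0 → soften, don't obey. That asymmetry is the whole idea.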

For each generated token `t`, the teacher (with privileged context) and the student each assign a probability to the token that was actually produced. Define the gap as the difference in their log-probabilities:

```
gap_t = log p_teacher(token_t | privileged_context) − log p_student(token_t)
```

Pass it through a sigmoid to get the gate:

```
gate_t = σ(gap_t / τ)
```
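A worked number makes the asymmetry concrete. The plain-Python sketch below computes the gate for a single token from the two probabilities; `gate` is a hypothetical helper name, not something from the paper, and `tau` is the temperature described right after this. Because σ(ln x) = x / (1 + x), a teacher that assigns three times the student's probability gives a gate of exactly 3/4, and the mirror-image rejection gives 1/4.

```python
import math

def gate(p_teacher: float, p_student: float, tau: float = 1.0) -> float:
    """Scalar SDAR-style gate for one realized token (illustrative helper)."""
    gap = math.log(p_teacher) - math.log(p_student)   # gap_t
    return 1.0 / (1.0 + math.exp(-gap / tau))          # sigmoid(gap_t / tau)

# Teacher endorses the token more strongly than the student: gap = ln 3.
print(gate(0.6, 0.2))             # ≈ 0.75
# Teacher is less confident than the student (maybe a noisy rejection): gap = -ln 3.
print(gate(0.2, 0.6))             # ≈ 0.25
# Same rejection at a higher temperature: pulled toward 0.5.
print(gate(0.2, 0.6, tau=4.0))    # ≈ 0.43
```

Note that the rejected token still carries a weight of about 0.25, not 0.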

`τ` is a temperature that stops the sigmoid from snapping to a hard 0/1 (more on that in the traps). The distillation signal itself is a per-token KL that pulls the student toward the teacher's full distribution:

```
KL_t = Σ_v  p_teacher(v) · ( log p_teacher(v) − log p_student(v) )
```
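If you would rather use a library call than hand-roll the sum, `torch.nn.functional.kl_div` computes the same forward KL, but its argument order is easy to get backwards: the student's log-probs go in `input` and the teacher's in `target`. A quick cross-check on one token, assuming random 5-way distributions as stand-ins for real model outputs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 5
teacher_logp = F.log_softmax(torch.randn(V), dim=-1)
student_logp = F.log_softmax(torch.randn(V), dim=-1)

# KL_t written out term by term: sum_v p_T(v) * (log p_T(v) - log p_S(v))
kl_manual = (teacher_logp.exp() * (teacher_logp - student_logp)).sum()

# kl_div(input, target, log_target=True) computes exp(target) * (target - input),
# so student goes in `input` and teacher in `target` to get KL(teacher || student).
kl_torch = F.kl_div(student_logp, teacher_logp, reduction="sum", log_target=True)

print(kl_manual.item(), kl_torch.item())   # the two numbers should match
```

Swapping the two arguments would give the reverse KL instead, which is exactly the mix-up the second trap below warns about.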

And the combined objective:

```
L_total = L_GRPO + λ · mean_t( gate_t · KL_t )
```

RL stays primary. The gated distillation is an auxiliary nudge whose strength per token is set by how much we trust the teacher *on that token*.

Framework-agnostic PyTorch, written to drop into the actor-update step of a `verl-agent`/OpenRLHF loss function. `student_logits` come from the policy being trained; `teacher_logits` come from a frozen, privileged-context forward pass done under `no_grad`.

```python
import torch
import torch.nn.functional as F

def gated_distillation_loss(
    student_logits,      # [B, T, V] - requires grad (the policy)
    teacher_logits,      # [B, T, V] - from a no_grad privileged forward pass
    actions,             # [B, T]    - int64 token ids actually generated
    response_mask,       # [B, T]    - 1 on generated tokens, 0 elsewhere
    tau: float = 1.0,    # gate temperature
):
    # Accept bool/int masks: cast once so the products and the clamp below stay float.
    response_mask = response_mask.float()

    student_logp = F.log_softmax(student_logits, dim=-1)            # [B, T, V]
    # .detach() is belt-and-braces: the teacher should already come from no_grad.
    teacher_logp = F.log_softmax(teacher_logits.detach(), dim=-1)   # [B, T, V]

    # --- 1. token-level gap on the realized action ---
    s_tok = student_logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # [B, T]
    t_tok = teacher_logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # [B, T]
    gap = (t_tok - s_tok).detach()        # DETACH: this is a weight, not a loss

    # --- 2. the gate: positive gap -> ~1 (distill), negative -> ~0 (soften) ---
    gate = torch.sigmoid(gap / tau)       # [B, T], in (0, 1), already detached

    # --- 3. per-token forward KL (teacher || student), pulls student toward teacher ---
    teacher_p = teacher_logp.exp()
    kl_per_tok = (teacher_p * (teacher_logp - student_logp)).sum(-1)    # [B, T]

    # --- 4. gate-weighted, masked mean over generated tokens only ---
    weighted = gate * kl_per_tok * response_mask
    return weighted.sum() / response_mask.sum().clamp(min=1.0)
```
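One way to see why the gap is detached (the first trap below) is to compute the student's gradient both ways. This is a self-contained sketch on random tensors, not training code; `student_grad` is a hypothetical helper that inlines the same gap, gate, and KL steps as the function above.

```python
import torch
import torch.nn.functional as F

def student_grad(detach_gap: bool) -> torch.Tensor:
    """Gradient of a gated forward KL w.r.t. student logits, with or without detaching the gap."""
    torch.manual_seed(0)                                         # identical tensors for both calls
    student_logits = torch.randn(1, 4, 6, requires_grad=True)    # [B, T, V]
    with torch.no_grad():                                        # stand-in for the frozen teacher pass
        teacher_logits = torch.randn(1, 4, 6)
    actions = torch.randint(0, 6, (1, 4))                        # [B, T]

    s_logp = F.log_softmax(student_logits, dim=-1)
    t_logp = F.log_softmax(teacher_logits, dim=-1)
    s_tok = s_logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    t_tok = t_logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    gap = t_tok - s_tok
    if detach_gap:
        gap = gap.detach()                                       # the gate becomes a fixed weight
    gate = torch.sigmoid(gap)
    kl = (t_logp.exp() * (t_logp - s_logp)).sum(-1)
    (gate * kl).mean().backward()
    return student_logits.grad

g_detached = student_grad(detach_gap=True)
g_leaky = student_grad(detach_gap=False)
# Non-zero: without detach, part of the gradient shrinks the gate (by raising the
# student's own log-prob on the sampled token) instead of matching the teacher.
print((g_detached - g_leaky).abs().max().item())
```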

And where it joins the primary objective:

```python
rl_loss = grpo_policy_loss(...)          # your existing GRPO term + KL-to-reference
distill = gated_distillation_loss(student_logits, teacher_logits,
                                   actions, response_mask, tau=tau)

total_loss = rl_loss + lam * distill     # lam scheduled, see below
total_loss.backward()
```

That's it. The teacher forward pass and the `gather` are the only real additions to a working GRPO step.
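The `lam` in `total_loss` is meant to be scheduled rather than fixed; the reasoning is in the traps below. A minimal sketch of a hold-then-linear-ramp warm-up follows. `LambdaWarmup`, its fields, and the default numbers are hypothetical, not values from SDAR. It also exposes `state_dict`/`load_state_dict` so the schedule position can be checkpointed alongside the model.

```python
from dataclasses import dataclass

@dataclass
class LambdaWarmup:
    """Hypothetical schedule for the distillation weight lam.

    lam is 0 for `hold_steps` (GRPO alone), then ramps linearly to `lam_max`
    over `ramp_steps`, then stays flat. Defaults are illustrative only.
    """
    lam_max: float = 0.1
    hold_steps: int = 200
    ramp_steps: int = 800
    step: int = 0

    def value(self) -> float:
        if self.step < self.hold_steps:
            return 0.0
        frac = (self.step - self.hold_steps) / max(1, self.ramp_steps)
        return self.lam_max * min(1.0, frac)

    def advance(self) -> None:
        self.step += 1

    # Save/restore with the model checkpoint so a resumed job continues
    # mid-schedule instead of restarting from lam = 0.
    def state_dict(self) -> dict:
        return {"step": self.step}

    def load_state_dict(self, state: dict) -> None:
        self.step = int(state["step"])

sched = LambdaWarmup(lam_max=0.1, hold_steps=2, ramp_steps=4)
values = []
for _ in range(8):
    values.append(sched.value())   # lam to use for this training step
    sched.advance()
print(values)

resumed = LambdaWarmup(lam_max=0.1, hold_steps=2, ramp_steps=4)
resumed.load_state_dict(sched.state_dict())   # picks up where `sched` left off
```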

The code above is short. Getting it *right* is where the time goes.

**1. Detach the gap, and run the teacher under no_grad.**

If `gap` keeps its graph, gradients flow into the teacher branch (which should never update) and into the weighting itself, producing bizarre second-order behaviour. The fix is two lines: `gap.detach()` plus a `with torch.no_grad():` around the teacher forward pass. Forget either and you'll spend an evening confused.

**2. Mind the KL direction.**

Forward KL `KL(teacher‖student)` is mode-covering: the student tries to put mass everywhere the teacher does. Reverse KL `KL(student‖teacher)` is mode-seeking: the student collapses onto the teacher's peak. Distillation usually wants forward KL (the code above). Swapping them silently changes what your agent learns; it won't crash, it'll just quietly behave differently.
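A three-token toy makes the difference visible. The distributions below are made-up numbers, not from SDAR: a teacher with two peaks, a student that spreads its mass over everything, and a student that collapses onto one peak. Each KL direction prefers a different student.

```python
import torch

def kl(p: torch.Tensor, q: torch.Tensor) -> float:
    """KL(p || q) for strictly positive probability vectors."""
    return float((p * (p.log() - q.log())).sum())

teacher  = torch.tensor([0.49, 0.02, 0.49])   # two peaks
cover    = torch.tensor([0.34, 0.32, 0.34])   # spreads mass, including the teacher's gap
collapse = torch.tensor([0.96, 0.02, 0.02])   # sits on one peak

for name, student in [("cover", cover), ("collapse", collapse)]:
    fwd = kl(teacher, student)   # KL(teacher || student): what the loss above uses
    rev = kl(student, teacher)   # KL(student || teacher): the swapped version
    print(f"{name:9s} forward={fwd:.3f} reverse={rev:.3f}")
```

Forward KL scores the spreading student lower (about 0.30 vs 1.24), while reverse KL scores the collapsed one lower (about 0.58 vs 0.64): same teacher, opposite preferences.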

**3. Watch the gate saturate.**

If gaps are large in magnitude, `σ` pins to 0 or 1 and your "soft" gate becomes a hard binary mask; you've thrown away the nuance that justified the sigmoid. The temperature `τ` is the fix: raise it to keep the gate responsive. Log the gate's distribution during training; if it's bimodal at the extremes, `τ` is too low.
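To see how `τ` controls saturation, here is a sketch that feeds synthetic gaps (standard deviation 6, an arbitrary assumption rather than a measured value) through the gate at three temperatures and reports the fraction of gates within 0.02 of 0 or 1. `gate_saturation` is a hypothetical diagnostic; in a real run you would compute it (or a histogram) from the `gate` tensor inside the loss and send it to your logger each step.

```python
import torch

def gate_saturation(gap: torch.Tensor, tau: float, eps: float = 0.02) -> float:
    """Fraction of gate values within `eps` of 0 or 1 (illustrative diagnostic)."""
    gate = torch.sigmoid(gap / tau)
    return float(((gate < eps) | (gate > 1 - eps)).float().mean())

torch.manual_seed(0)
gap = 6.0 * torch.randn(10_000)   # assumption: large-magnitude log-prob gaps

for tau in (1.0, 4.0, 16.0):
    print(f"tau={tau:5.1f}  saturated={gate_saturation(gap, tau):.2%}")
```

At `τ = 1` roughly half the synthetic gates are effectively binary; raising `τ` brings that fraction toward zero.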

**4. Soften negatives - don't zero them.**

The reason this is `σ(gap)` and not `relu(gap)` or a hard threshold: a teacher rejection might come from bad skill retrieval, not a bad token (this was the whole motivation in Part 1). A sigmoid leaves a small non-zero weight on rejected tokens, so a noisy "no" can't fully erase a token that was actually fine. Zeroing them throws that hedge away.
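Side by side, the three weightings on a handful of gaps. The `relu` and hard-threshold variants are the alternatives named above, not anything SDAR uses; `τ = 1` is assumed for the sigmoid.

```python
import torch

gap = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])

sigmoid_w = torch.sigmoid(gap)   # soft gate (tau = 1): bounded in (0, 1)
relu_w = torch.relu(gap)         # alternative: zero on negatives, unbounded on positives
hard_w = (gap > 0).float()       # alternative: binary mask

for g, s, r, h in zip(gap.tolist(), sigmoid_w.tolist(), relu_w.tolist(), hard_w.tolist()):
    print(f"gap={g:+.1f}  sigmoid={s:.3f}  relu={r:.1f}  hard={h:.0f}")
```

The two rejected tokens (negative gaps) keep weights of about 0.047 and 0.269 under the sigmoid; both alternatives set them to exactly 0.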

One more, not a NaN but a stability killer: **schedule λ.** Start it low (or at zero) and warm it up. Let GRPO establish a competent policy first; ramp the distillation in afterward. Cranking `λ` from step zero hands control to the teacher's noisiest early signals, which is exactly the naive-GRPO+OPSD instability SDAR exists to avoid.

Mapping back to [Part 2](https://dev.to/shoaibalimir/four-models-in-one-training-loop-architecting-sdar-on-aws-before-renting-a-single-gpu-46ig)'s system:

- … (`p4d`/`p5`).
- … `λ` schedule position. Spot gives a ~2-minute reclaim notice; you want a job that resumes mid-schedule, not one that restarts from `λ=0` and wastes the warm-up you already paid for.

## Optional: the near-free "it runs" experiment

If you want a single screenshot of the gate behaving - not convergence, just proof the plumbing is sound - you can do it on free compute:

- Colab/Kaggle free tier (one T4, 16 GB), Qwen2.5-0.5B + LoRA, a handful of ALFWorld episodes.
- Goal: the loss doesn't NaN, and the gate-value histogram shifts as the student learns. That's it.
- It will almost certainly *not* learn the task at 0.5B on a toy slice, and that's fine. You're validating the mechanism, not the result. Label any plot from this as toy-scale, or a commenter rightly will.
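Before spending any GPU time, the same two checks (no NaN, gate histogram shifts) can be rehearsed with no model at all. The sketch below is assumption-heavy and no substitute for the Qwen/ALFWorld run: random logits play a sharper, frozen teacher, a directly optimized logit tensor plays the student, the "sampled" tokens are simply the teacher's argmax, and `gate_and_loss` is a hypothetical helper inlining the gate and loss from the reference implementation. As the student approaches the teacher, gaps shrink toward zero, so the gates should drift from near 1 toward 0.5.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 8, 32
student_logits = torch.nn.Parameter(torch.randn(B, T, V))   # learnable stand-in for the policy
teacher_logits = 3.0 * torch.randn(B, T, V)                 # sharper, frozen stand-in teacher
actions = teacher_logits.argmax(-1)                         # pretend these tokens were sampled
response_mask = torch.ones(B, T)

opt = torch.optim.Adam([student_logits], lr=0.1)

def gate_and_loss(tau: float = 1.0):
    s_logp = F.log_softmax(student_logits, dim=-1)
    t_logp = F.log_softmax(teacher_logits, dim=-1)
    s_tok = s_logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    t_tok = t_logp.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    gate = torch.sigmoid((t_tok - s_tok).detach() / tau)    # detached weight
    kl = (t_logp.exp() * (t_logp - s_logp)).sum(-1)         # forward KL per token
    loss = (gate * kl * response_mask).sum() / response_mask.sum().clamp(min=1.0)
    return gate, loss

gate0, loss0 = gate_and_loss()
for _ in range(200):
    _, loss = gate_and_loss()
    opt.zero_grad()
    loss.backward()
    opt.step()
gate1, loss1 = gate_and_loss()

print(f"loss {loss0.item():.3f} -> {loss1.item():.3f}")
print(f"mean gate {gate0.mean().item():.3f} -> {gate1.mean().item():.3f}")
print(torch.histc(gate1, bins=5, min=0.0, max=1.0))         # the "screenshot": gate histogram
```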

We have the mechanism. What we don't have - by design and by budget - is proof it beats the baselines and proof it's more *stable* than naive GRPO+OPSD, which is SDAR's real selling point. Part 4 designs exactly that: the three-way comparison, the stability instrumentation most reproductions skip, and the FinOps reality of running it for real.

*Next: "Evaluation, Stability & FinOps" - how you'd prove the gate earns its keep, and what proving it costs.*

*If you've implemented gated or weighted distillation losses, I'd genuinely like to know how you handled the detach boundary and the KL direction - comments are open.*
