Gated Delta Network (GDN) 1 is a linear attention (LA) variant that harnesses both gating for memory control and the delta update rule for precise memory modifications.
It is used in the incredible Qwen3-Next: for every 3 GDN layers there is 1 full attention layer (a 3:1 linear-to-full ratio).
GDN workflow
In this post, I'm going to walk through the algorithm design and the corresponding hardware parallel optimizations in Megatron.
Variants of linear attention
Linear attention maintains a matrix-valued state St∈Rdv×dk that acts as a key-value associative memory.
Each step absorbs the current key and value, then emits an output via ot=Stqt.
Let's first look at how the variants differ in how St is updated.
Mamba2 2 introduces a data-dependent scalar gate αt∈(0,1) that decays the entire state before each write:
St=αtSt−1+vtkt⊤
This forgets bulk context cheaply, but it cannot remove a single key-value pair without also decaying every other association at the same rate.
Unrolling the recurrence shows that St is a weighted sum of past key-value outer products, with each historical contribution multiplied by the product of all gates emitted since:
St=∑i=1t(∏j=i+1tαj)viki⊤
Define the cumulative decay γt=∏i=1tαi.
Then ∏j=i+1tαj=γt/γi, so the state and output collapse to
St=∑i=1t(γt/γi)viki⊤, ot=Stqt=∑i=1t(γt/γi)(ki⊤qt)vi
Stacking the per-token vectors into row matrices Q,K,V∈RL×⋅ gives the parallel matrix form
O=((QK⊤)⊙Γ)V
where Γ∈RL×L is the decay-aware causal mask with Γij=γi/γj for i≥j and 0 otherwise.
Compared with the standard causal mask M of vanilla linear attention, Γ just replaces each 1 entry with the appropriate ratio of cumulative gates, so token i sees token j≤i scaled by γi/γj — exactly the surviving fraction of α products between them.
This is the form Mamba2 trains in: a single fused matmul per layer, no token-by-token iteration, and the same Γ structure is what GDN later reuses inside each chunk.
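To make the equivalence concrete, here is a small torch sketch (illustrative shapes, not a fused kernel) that builds Γ from the cumulative gates and checks the parallel form against the token-by-token recurrence:

```python
import torch

torch.manual_seed(0)
L, dk, dv = 128, 32, 16
q, k, v = torch.randn(L, dk), torch.randn(L, dk), torch.randn(L, dv)
alpha = 0.9 + 0.1 * torch.rand(L)            # data-dependent gates in (0, 1)

# recurrent form: S_t = alpha_t S_{t-1} + v_t k_t^T, o_t = S_t q_t
S = torch.zeros(dv, dk)
o_rec = []
for t in range(L):
    S = alpha[t] * S + torch.outer(v[t], k[t])
    o_rec.append(S @ q[t])
o_rec = torch.stack(o_rec)

# parallel form: O = ((Q K^T) ⊙ Γ) V with Γ_ij = γ_i / γ_j for i >= j, else 0
log_gamma = torch.log(alpha).cumsum(0)                        # log γ_t
Gamma = torch.tril(torch.exp(log_gamma[:, None] - log_gamma[None, :]))
o_par = ((q @ k.T) * Gamma) @ v

print(torch.allclose(o_rec, o_par, atol=1e-3))                # True
```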
DeltaNet 3 instead applies a generalized Householder update to overwrite one slot at a time, with writing strength βt∈(0,1):
St=St−1(I−βtktkt⊤)+βtvtkt⊤
The (I−βtktkt⊤) factor subtracts the value currently associated with kt before the new vt is written, giving precise edits but no bulk-clear mechanism.
The gated delta rule combines both into a single transition:
St=αtSt−1(I−βtktkt⊤)+βtvtkt⊤
αt→0 wipes the state regardless of βt (Mamba2-style hard reset).
αt→1 with βt→1 falls back to the pure delta rule.
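As a reference point for the rest of the post, here is a minimal token-by-token torch sketch of the gated delta rule; the zero initial state and shapes are illustrative, and real implementations use the chunkwise kernel described next:

```python
import torch

def gated_delta_rule_reference(q, k, v, alpha, beta):
    """Token-by-token reference of the gated delta rule (illustrative, not a fused kernel).
    q, k: (L, dk); v: (L, dv); alpha, beta: (L,) with entries in (0, 1).
    Rows of k are assumed L2-normalized, which keeps the Householder factor well-conditioned."""
    L, dk = k.shape
    dv = v.shape[-1]
    S = torch.zeros(dv, dk, dtype=v.dtype)           # matrix-valued state S_t
    outputs = []
    for t in range(L):
        # decay the whole state, erase what is stored under k_t, then write beta_t v_t k_t^T
        S = alpha[t] * S @ (torch.eye(dk) - beta[t] * torch.outer(k[t], k[t])) \
            + beta[t] * torch.outer(v[t], k[t])
        outputs.append(S @ q[t])                      # o_t = S_t q_t
    return torch.stack(outputs)
```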
Chunkwise parallel training
Running the recurrence token-by-token is memory-bound and leaves tensor cores idle.
GDN trains chunkwise: split the sequence into chunks of size C, propagate S[t] between chunks, and express the work inside a chunk as dense matmuls.
Mathematical details: how to convert GDN into a parallel form
Partially expanding the GDN recurrence over r steps inside chunk [t] splits the running state into a gated transition product F[t]r and a gated accumulated-write sum G[t]r:
S[t]r=S[t]0F[t]r+G[t]r, F[t]r=∏i=1rα[t]i(I−β[t]ik[t]ik[t]i⊤), G[t]r=∑i=1rβ[t]iv[t]ik[t]i⊤∏j=i+1rα[t]j(I−β[t]jk[t]jk[t]j⊤) (1)
Here S[t]0 is the state entering chunk [t] and S[t]r the state after its first r tokens.
Each α[t]i is a scalar, so it factors out of the matrix products.
Pulling the cumulative gate γ[t]r=∏i=1rα[t]i out of the first term gives F[t]r=γ[t]rP[t]r, where P[t]r is the β-only Householder product:
P[t]r=∏i=1r(I−β[t]ik[t]ik[t]i⊤)
For G[t]r the inner α[t]j's collapse to ratios γ[t]r/γ[t]i that scale the i-th historical write — exactly the entries of the decay-aware mask Γ[t] that re-enter at the matrix level below.
We first derive the WY representation of the β-only building blocks P[t]r and H[t]r (these are DeltaNet's results), then re-insert the γ factors to recover F[t]r and G[t]r. Here H[t]r is G[t]r with the α's stripped:
H[t]r=∑i=1rβ[t]iv[t]ik[t]i⊤∏j=i+1r(I−β[t]jk[t]jk[t]j⊤)
Let's say we want to prove that:
P[t]r=I−∑i=1rw[t]ik[t]i⊤, with w[t]r=β[t]r(k[t]r−∑i=1r−1w[t]i(k[t]i⊤k[t]r)) (2)
H[t]r=∑i=1ru[t]ik[t]i⊤, with u[t]r=β[t]r(v[t]r−∑i=1r−1u[t]i(k[t]i⊤k[t]r)) (3)
The rest of this section derives (2) and (3) from (1), then turns the per-step recursion into a single C×C triangular solve.
Drop the [t] subscript for readability.
Folding Pr into the W form
Base case (r=1). P1=I−β1k1k1⊤ already matches the WY form with w1=β1k1, consistent with the r=1 instance of (2) (the inner sum is empty).
Inductive step. Assume Pr−1=I−∑i=1r−1wiki⊤.
Following the left-to-right product convention of (1):
Pr=Pr−1(I−βrkrkr⊤)=Pr−1−βr(Pr−1kr)kr⊤
The only nontrivial piece is the vector Pr−1kr∈Rdk:
Pr−1kr=kr−∑i=1r−1wi(ki⊤kr)
which is exactly the bracketed expression inside the definition of wr in (2).
Defining
wr:=βr(kr−∑i=1r−1wi(ki⊤kr))
reduces the previous line to Pr=Pr−1−wrkr⊤, and substituting the inductive form for Pr−1 yields
Pr=I−∑i=1rwiki⊤
closing the induction.
Folding Hr into the U form
The accumulated-write term Hr admits its own one-step recurrence.
Peel off the i=r summand and factor the remaining product:
Hr=βrvrkr⊤+Hr−1(I−βrkrkr⊤)=Hr−1+βr(vr−Hr−1kr)kr⊤
Inductive claim. Hr=∑i=1ruiki⊤ with the recursion from (3).
For r=1, u1=β1v1 (empty inner sum), so u1k1⊤=β1v1k1⊤=H1.
Assuming the claim at r−1, we have:
Hr−1kr=∑i=1r−1ui(ki⊤kr)
so the bracketed term in the one-step recurrence becomes
βr(vr−Hr−1kr)=βr(vr−∑i=1r−1ui(ki⊤kr))=:ur
which matches the definition of ur in (3).
Substituting back gives Hr=Hr−1+urkr⊤=∑i=1ruiki⊤, closing the induction.
Intuition. ur is the r-th write after the previously stored writes are projected away along the new key direction.
Compared with the W recursion, the only change is the source vector: kr for wr (because we're tracking the transition matrix), vr for ur (because we're tracking the accumulated content).
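A quick numerical check of the two inductions, as a small torch sketch with random unit keys (sizes are arbitrary): it builds Pr and Hr directly from the Householder products, rebuilds them from the wr and ur recursions, and compares.

```python
import torch

torch.manual_seed(0)
r, dk, dv = 8, 16, 8                           # illustrative sizes
k = torch.nn.functional.normalize(torch.randn(r, dk), dim=-1)   # unit keys
v = torch.randn(r, dv)
beta = torch.rand(r)

# direct evaluation of the beta-only building blocks
P = torch.eye(dk)
H = torch.zeros(dv, dk)
for i in range(r):
    house = torch.eye(dk) - beta[i] * torch.outer(k[i], k[i])
    P = P @ house                              # P_i = P_{i-1} (I - beta_i k_i k_i^T)
    H = H @ house + beta[i] * torch.outer(v[i], k[i])

# WY form rebuilt from the w / u recursions in (2) and (3)
w, u = torch.zeros(r, dk), torch.zeros(r, dv)
for i in range(r):
    proj = k[:i] @ k[i]                        # k_j^T k_i for j < i
    w[i] = beta[i] * (k[i] - (proj[:, None] * w[:i]).sum(0))
    u[i] = beta[i] * (v[i] - (proj[:, None] * u[:i]).sum(0))

print(torch.allclose(P, torch.eye(dk) - w.T @ k, atol=1e-5))    # P_r = I - sum_i w_i k_i^T
print(torch.allclose(H, u.T @ k, atol=1e-5))                    # H_r = sum_i u_i k_i^T
```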
From sequential recursion to a triangular solve: W=TK
The recursion in (2) looks inherently sequential — wr depends on every earlier wi.
Stacking the per-step vectors row-wise into W,K∈RC×dk (so Wr,:=wr⊤, Kr,:=kr⊤) reveals that it is actually a single triangular linear system in disguise.
Step 1 — transpose into row form. Transposing the wr recursion in (2):
wr⊤=βrkr⊤−βr∑i=1r−1(ki⊤kr)wi⊤
so row r of W satisfies
Wr,:=βrKr,:−βr∑i=1r−1(ki⊤kr)Wi,: (4)
Step 2 — identify coefficients with entries of diag(β)KK⊤. Define A:=diag(β)KK⊤. For any r,i:
Ar,i=βr(KK⊤)r,i=βrkr⊤ki=βrki⊤kr
(the last equality uses that ki⊤kr is a scalar). So the coefficient βr(ki⊤kr) in (4) is exactly Ar,i.
Step 3 — restrict to the strictly lower-triangular part. The sum in (4) runs only over i=1,…,r−1, never including i≥r.
Equivalently, only the strictly lower-triangular entries of A matter.
Let L:=strictLower(A), with Lr,i=Ar,i for i<r and 0 otherwise.
Then
(I+L)W=diag(β)K, i.e. W=TK with T:=(I+L)−1diag(β)
The exact same argument applied to the ur recursion in (3) replaces K on the right-hand side with V (because ur's "source" vector is vr rather than kr), giving U=TV with the same T.
Cost view. The sequential reading computes w1,…,wC one at a time with C data-dependent steps. The matrix reading (I+L)W=diag(β)K is a single C×C lower-triangular system with dk right-hand sides (one per column of K), solved by a batched forward substitution — tensor-core-friendly, and the same T is reused for both W and U.
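A small torch sketch of this cost view (chunk size and dimensions are illustrative): the sequential reading and a single triangular solve with torch.linalg.solve_triangular produce the same W and U, with one system shared by both right-hand sides.

```python
import torch

torch.manual_seed(0)
C, dk, dv = 64, 32, 16                        # chunk size and head dims (illustrative)
K = torch.nn.functional.normalize(torch.randn(C, dk), dim=-1)   # unit keys, as in GDN
V = torch.randn(C, dv)
beta = torch.rand(C)                          # writing strengths in (0, 1)

# sequential reading: C data-dependent steps for w_r and u_r
W_seq, U_seq = torch.zeros(C, dk), torch.zeros(C, dv)
for r in range(C):
    proj = K[:r] @ K[r]                       # k_i^T k_r for i < r
    W_seq[r] = beta[r] * (K[r] - (proj[:, None] * W_seq[:r]).sum(0))
    U_seq[r] = beta[r] * (V[r] - (proj[:, None] * U_seq[:r]).sum(0))

# matrix reading: one C x C lower-triangular solve shared by W and U
A = torch.diag(beta) @ (K @ K.T)              # A_{r,i} = beta_r k_r^T k_i
L = torch.tril(A, diagonal=-1)                # strictly lower-triangular part
rhs = torch.cat([beta[:, None] * K, beta[:, None] * V], dim=1)   # diag(beta) [K | V]
WU = torch.linalg.solve_triangular(torch.eye(C) + L, rhs, upper=False)
W_mat, U_mat = WU[:, :dk], WU[:, dk:]

print(torch.allclose(W_seq, W_mat, atol=1e-5), torch.allclose(U_seq, U_mat, atol=1e-5))
```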
Unified matrix view
Stacking the per-step vectors row-wise, (2) and (3) compactly read
P[t]=I−W[t]⊤K[t] and H[t]=U[t]⊤K[t], with W[t]=TK[t] and U[t]=TV[t]
So a product of r rank-1 perturbations plus an accumulated rank-r write — together carrying O(r(dk+dv)) degrees of freedom — are expressed by a single batched C×C triangular inverse applied to K[t] and V[t], replacing r sequential Householder-and-write applications with tensor-core-friendly matmuls.
Re-inserting the α gates
So far we have the WY representation of the β-only P[t]r and H[t]r.
Putting the α's back recovers F[t]r and G[t]r from (1).
For F this is a global scalar: F[t]r=γ[t]rP[t]r, so the row-stacked form is F[t]r=γ[t]r(I−W[t]⊤K[t]) with the same W[t].
For G the α's distribute as ratios: the i-th historical write enters with multiplier γ[t]r/γ[t]i, which is exactly the (r,i) entry of the decay-aware causal mask Γ[t].
Tracing through the W-form argument, every βr(ki⊤kr) coefficient picks up an extra γ[t]r/γ[t]i factor, which means K[t]K[t]⊤ inside T is replaced by Γ[t]⊙K[t]K[t]⊤.
The resulting gated UT transform gives the row-stacked accumulated-write matrix directly:
U[t]=(I+strictLower(diag(β[t])(Γ[t]⊙K[t]K[t]⊤)))−1diag(β[t])V[t]
With appropriately decayed copies of the per-chunk matrices (each vector decayed either to the first or to the last position of the chunk), the cross-chunk recurrence and per-chunk output also reduce to dense matrix products,
where M is the standard causal mask and W[t] comes from the same UT transform applied to K[t].
Every intra-chunk operation is now matmul-shaped, so the algorithm preserves the gated delta semantics exactly while running tensor-core-bound.
The inter-chunk computation, however, is still sequential.
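The claim about Γ[t]⊙K[t]K[t]⊤ can be checked numerically. The sketch below (illustrative sizes, not the fla kernel) evaluates the gated accumulated write directly from the recurrence, recovers U from the gated UT transform, and re-stacks it against keys decayed to the chunk's last position:

```python
import torch

torch.manual_seed(0)
C, dk, dv = 64, 32, 16
K = torch.nn.functional.normalize(torch.randn(C, dk), dim=-1)
V = torch.randn(C, dv)
alpha = 0.9 + 0.1 * torch.rand(C)             # gates in (0, 1)
beta = torch.rand(C)
gamma = alpha.cumprod(0)                       # cumulative decay inside the chunk

# direct evaluation of the gated accumulated write (zero incoming state)
G = torch.zeros(dv, dk)
for t in range(C):
    G = alpha[t] * G @ (torch.eye(dk) - beta[t] * torch.outer(K[t], K[t])) \
        + beta[t] * torch.outer(V[t], K[t])

# gated UT transform: K K^T inside T is replaced by Gamma ⊙ K K^T
Gamma = torch.tril(gamma[:, None] / gamma[None, :])
A = torch.diag(beta) @ (Gamma * (K @ K.T))
L = torch.tril(A, diagonal=-1)
U = torch.linalg.solve_triangular(torch.eye(C) + L, beta[:, None] * V, upper=False)

# re-stack with each k_i decayed to the last position of the chunk
K_decayed = (gamma[-1] / gamma)[:, None] * K
print(torch.allclose(G, U.T @ K_decayed, atol=1e-4))             # True
```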
Block compute flow
The block diagram realizes one step of this recurrence.
Let xt be the block input.
Four linear projections fan xt out into (q,k), v, (α,β), and a residual gate g.
The q,k path applies a short causal Conv, then SiLU, then L2 normalization on each head.
The L2 step pins ∥kt∥=1 so that (I−βtktkt⊤) stays well-conditioned after many delta updates.
The v path uses the same Conv then SiLU stack without L2, since values are the content being written rather than the lookup direction.
The α,β path is a plain linear projection; αt uses Mamba2's parameterization and βt a sigmoid so both stay in (0,1).
(qt,kt,vt,αt,βt) feed the gated delta rule and yield ot=Stqt.
ot is RMS-normalized, gated elementwise by SiLU(g) (the residual branch on the right of the diagram), then projected back to the model dimension by the top Linear.
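Putting the block together, here is a single-head torch sketch of the data flow. All module names, the conv length, and the sigmoid stand-in for Mamba2's α parameterization are my own illustrative choices, and the token-by-token loop stands in for the chunkwise kernel; this is not the actual Qwen3-Next / Megatron layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GDNBlockSketch(nn.Module):
    """Single-head sketch of the block diagram above (illustrative, not the real layer)."""

    def __init__(self, d_model, d_k, d_v, conv_size=4):
        super().__init__()
        self.qk_proj = nn.Linear(d_model, 2 * d_k)
        self.v_proj = nn.Linear(d_model, d_v)
        self.ab_proj = nn.Linear(d_model, 2)           # (alpha, beta) logits
        self.g_proj = nn.Linear(d_model, d_v)          # residual gate branch
        # short depthwise causal convolutions along the sequence axis
        self.qk_conv = nn.Conv1d(2 * d_k, 2 * d_k, conv_size, padding=conv_size - 1, groups=2 * d_k)
        self.v_conv = nn.Conv1d(d_v, d_v, conv_size, padding=conv_size - 1, groups=d_v)
        self.out_proj = nn.Linear(d_v, d_model)

    def _causal_conv(self, conv, x):                   # x: (L, C) -> (L, C)
        L = x.shape[0]
        return conv(x.T.unsqueeze(0))[0, :, :L].T      # trim the right overhang => causal

    def forward(self, x):                              # x: (L, d_model), batch omitted
        L = x.shape[0]
        qk = F.silu(self._causal_conv(self.qk_conv, self.qk_proj(x)))
        q, k = qk.chunk(2, dim=-1)
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)    # L2 norm keeps ||k_t|| = 1
        v = F.silu(self._causal_conv(self.v_conv, self.v_proj(x)))
        alpha, beta = torch.sigmoid(self.ab_proj(x)).unbind(-1)  # both in (0, 1)
        g = self.g_proj(x)

        # gated delta rule, token by token (the chunkwise kernel replaces this loop)
        d_k, d_v = k.shape[-1], v.shape[-1]
        S = x.new_zeros(d_v, d_k)
        o = []
        for t in range(L):
            S = alpha[t] * S @ (torch.eye(d_k, device=x.device) - beta[t] * torch.outer(k[t], k[t])) \
                + beta[t] * torch.outer(v[t], k[t])
            o.append(S @ q[t])                         # o_t = S_t q_t
        o = torch.stack(o)

        o = o * torch.rsqrt(o.pow(2).mean(-1, keepdim=True) + 1e-6)  # RMS norm (no learned scale)
        return self.out_proj(o * F.silu(g))                          # SiLU(g)-gated, then project
```

For example, GDNBlockSketch(256, 64, 128)(torch.randn(16, 256)) returns a (16, 256) tensor.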
Cost summary vs full SDPA
The chunkwise algorithm replaces SDPA's O(L2) sequence-axis cost with a fixed-size recurrent state plus per-chunk local work.
Let L be the sequence length, d the per-head dimension (taking dk=dv=d for brevity), and C the chunk size (C=64 in fla).
| Quantity (per head, per layer) | Full SDPA | GDN chunkwise |
| --- | --- | --- |
| Training compute | O(L2d) | O(LCd+Ld2) |
| Training memory (activations) | O(Ld) with FlashAttention | O(Ld)+O(d2) state |
| Inference compute per generated token | O(Ld) (scan KV cache) | O(d2) (matvec with state) |
| Inference memory per generated token | O(Ld), KV cache grows | O(d2), state is fixed |
Training. SDPA scales quadratically with L along the sequence axis.
GDN turns the quadratic term into a chunk-local O(LCd) contribution (linear in L once C is fixed) plus a cross-chunk O(Ld2) term from sweeping the d×d state through the L/C chunks.
For L≫max(C,d) the saving in attention-core flops is roughly a factor of L/(C+d).
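Plugging in illustrative numbers (L=32768, d=128, C=64) makes the ratio concrete; treating the two GDN terms literally gives a saving of about L/(C+d):

```python
# rough attention-core flop counts per head per layer (constant factors dropped)
L, d, C = 32768, 128, 64
sdpa = L**2 * d                   # O(L^2 d)
gdn = L * C * d + L * d**2        # O(LCd) intra-chunk + O(Ld^2) cross-chunk
print(sdpa / gdn)                 # = L / (C + d) ≈ 170.7
```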
Inference. SDPA must attend to every prior key, so per-token compute and KV-cache memory both grow linearly with L.
GDN compresses everything into the dk×dv matrix St: each new token costs one O(d2) matvec, and the memory footprint never grows with context length.
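A similarly rough memory comparison per head, with the same illustrative L and d:

```python
# rough per-head memory at generation time (float counts, constants dropped)
L, d = 32768, 128
sdpa_kv_cache = 2 * L * d         # keys + values kept for every past token
gdn_state = d * d                 # fixed d_k x d_v state, independent of context length
print(sdpa_kv_cache / gdn_state)  # = 2L / d = 512
```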