Gated Delta Network

May 11, 2026

Gated Delta Network (GDN) 1 is a linear attention (LA) variant that combines a gating mechanism for memory control with the delta update rule for precise memory modifications. It is used in Qwen3-Next: for every 3 GDN layers there is 1 full-attention layer (a 3:1 linear:full ratio).

(Figure: GDN workflow)

In this post, I'm going to walk through the algorithm design and the corresponding hardware parallel optimizations in Megatron.

Variants of linear attention

Linear attention maintains a matrix-valued state $\mathbf{S}_t \in \mathbb{R}^{d_v \times d_k}$ that acts as a key-value associative memory. Each step absorbs the current key and value, then emits an output via $\mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t$. The variants differ mainly in how $\mathbf{S}_t$ is updated.

Mamba2 2 introduces a data-dependent scalar gate $\alpha_t \in (0, 1)$ that decays the entire state before each write:

$$\mathbf{S}_t = \textcolor{blue}{\alpha_t}\, \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}_t^\top$$

This forgets bulk context cheaply, but it cannot remove a single key-value pair without also decaying every other association at the same rate.

Unrolling the recurrence shows that $\mathbf{S}_t$ is a weighted sum of past key-value outer products, with each historical contribution multiplied by the product of all gates emitted since:

$$\mathbf{S}_t = \sum_{i=1}^t \Bigl(\prod_{j=i+1}^t \alpha_j\Bigr)\, \mathbf{v}_i \mathbf{k}_i^\top$$

Define the cumulative decay $\gamma_t = \prod_{i=1}^t \alpha_i$. Then $\prod_{j=i+1}^t \alpha_j = \gamma_t / \gamma_i$, so the state and output collapse to

$$\mathbf{S}_t = \sum_{i=1}^t \frac{\gamma_t}{\gamma_i}\, \mathbf{v}_i \mathbf{k}_i^\top, \qquad \mathbf{o}_t = \sum_{i=1}^t \frac{\gamma_t}{\gamma_i}\, (\mathbf{k}_i^\top \mathbf{q}_t)\, \mathbf{v}_i$$

Stacking the per-token vectors into row matrices $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{L \times \cdot}$ gives the parallel matrix form

$$\mathbf{O} = \bigl((\mathbf{Q} \mathbf{K}^\top) \odot \boldsymbol{\Gamma}\bigr)\, \mathbf{V}$$

where $\boldsymbol{\Gamma} \in \mathbb{R}^{L \times L}$ is the decay-aware causal mask with $\boldsymbol{\Gamma}_{ij} = \gamma_i / \gamma_j$ for $i \geq j$ and $0$ otherwise. Compared with the standard causal mask $\mathbf{M}$ of vanilla linear attention, $\boldsymbol{\Gamma}$ just replaces each $1$ entry with the appropriate ratio of cumulative gates, so token $i$ sees token $j \leq i$ scaled by $\gamma_i / \gamma_j$ — exactly the surviving fraction of $\alpha$ products between them. This is the form Mamba2 trains in: a single fused matmul per layer, no token-by-token iteration, and the same $\boldsymbol{\Gamma}$ structure is what GDN later reuses inside each chunk.
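The recurrent and parallel readings can be checked against each other numerically. A minimal numpy sketch (all names are mine, not from any library), comparing the token-by-token recurrence with the masked-matmul form:

```python
import numpy as np

rng = np.random.default_rng(0)
L, dk, dv = 6, 4, 3
Q = rng.normal(size=(L, dk)); K = rng.normal(size=(L, dk)); V = rng.normal(size=(L, dv))
alpha = rng.uniform(0.5, 1.0, size=L)

# Recurrent form: S_t = alpha_t S_{t-1} + v_t k_t^T, o_t = S_t q_t
S = np.zeros((dv, dk)); O_rec = np.zeros((L, dv))
for t in range(L):
    S = alpha[t] * S + np.outer(V[t], K[t])
    O_rec[t] = S @ Q[t]

# Parallel form: O = ((Q K^T) ⊙ Γ) V with Γ_ij = γ_i / γ_j for i ≥ j
gamma = np.cumprod(alpha)
Gamma = np.tril(gamma[:, None] / gamma[None, :])
O_par = ((Q @ K.T) * Gamma) @ V

assert np.allclose(O_rec, O_par)
```

Note the diagonal of $\boldsymbol{\Gamma}$ is 1: the current write is never decayed, only past contributions are.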

DeltaNet 3 instead applies a generalized Householder term to overwrite one slot at a time, with writing strength $\beta_t \in (0, 1)$:

$$\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I} - \textcolor{blue}{\beta_t}\, \mathbf{k}_t \mathbf{k}_t^\top) + \textcolor{blue}{\beta_t}\, \mathbf{v}_t \mathbf{k}_t^\top$$

The $(\mathbf{I} - \beta_t \mathbf{k}_t \mathbf{k}_t^\top)$ factor subtracts the value currently associated with $\mathbf{k}_t$ before the new $\mathbf{v}_t$ is written, giving precise edits but no bulk-clear mechanism.

The gated delta rule combines both into a single transition:

$$\mathbf{S}_t = \mathbf{S}_{t-1}\bigl(\textcolor{blue}{\alpha_t}(\mathbf{I} - \textcolor{blue}{\beta_t} \mathbf{k}_t \mathbf{k}_t^\top)\bigr) + \textcolor{blue}{\beta_t} \mathbf{v}_t \mathbf{k}_t^\top, \qquad \mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t$$

Two corners recover the parents:

  • $\alpha_t \to 0$ wipes the state regardless of $\beta_t$ (Mamba2-style hard reset).
  • $\alpha_t \to 1$ with $\beta_t \to 1$ falls back to the pure delta rule.
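Both corner cases are easy to verify on a single step. A small numpy sketch (the helper `gdn_step` is mine, not an API from any library); the key is L2-normalized, matching what the GDN block enforces:

```python
import numpy as np

rng = np.random.default_rng(1)
dk, dv = 4, 3
S = rng.normal(size=(dv, dk))
k = rng.normal(size=dk); k /= np.linalg.norm(k)   # unit key, as GDN's L2 norm enforces
v = rng.normal(size=dv)

def gdn_step(S, k, v, alpha, beta):
    # S_t = S_{t-1} (alpha (I - beta k k^T)) + beta v k^T
    dk = k.shape[0]
    return S @ (alpha * (np.eye(dk) - beta * np.outer(k, k))) + beta * np.outer(v, k)

# alpha -> 0: the old state is wiped, only the fresh write survives
S0 = gdn_step(S, k, v, alpha=0.0, beta=0.7)
assert np.allclose(S0, 0.7 * np.outer(v, k))

# alpha = beta = 1: pure delta rule — querying the slot for k now returns v exactly
S1 = gdn_step(S, k, v, alpha=1.0, beta=1.0)
assert np.allclose(S1 @ k, v)
```

The second assertion relies on $\|\mathbf{k}\| = 1$: with a unit key, $(\mathbf{I} - \mathbf{k}\mathbf{k}^\top)$ annihilates the $\mathbf{k}$ direction completely, so the overwrite is exact.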

Chunkwise parallel training

Running the recurrence token-by-token is memory-bound and leaves tensor cores idle. GDN trains chunkwise: split the sequence into chunks of size $C$, propagate $\mathbf{S}_{[t]}$ between chunks, and express the work inside a chunk as dense matmuls.

Mathematical details: how to convert GDN into a parallel form

Partially expanding the GDN recurrence over $r$ steps inside chunk $[t]$ splits the running state into a gated transition product $\mathbf{F}_{[t]}^r$ and a gated accumulated-write sum $\mathbf{G}_{[t]}^r$:

$$\mathbf{S}_{[t]}^r = \mathbf{S}_{[t]}\, \underbrace{\prod_{i=1}^r \textcolor{blue}{\alpha_{[t]}^i}\bigl(\mathbf{I} - \beta_{[t]}^i \mathbf{k}_{[t]}^i \mathbf{k}_{[t]}^{i\top}\bigr)}_{=:\, \mathbf{F}_{[t]}^r}\; +\; \underbrace{\sum_{i=1}^r \beta_{[t]}^i \mathbf{v}_{[t]}^i \mathbf{k}_{[t]}^{i\top} \prod_{j=i+1}^r \textcolor{blue}{\alpha_{[t]}^j}\bigl(\mathbf{I} - \beta_{[t]}^j \mathbf{k}_{[t]}^j \mathbf{k}_{[t]}^{j\top}\bigr)}_{=:\, \mathbf{G}_{[t]}^r} \tag{1}$$

Each $\alpha_{[t]}^i$ is a scalar, so it factors out of the matrix products. Pulling the cumulative gate $\gamma_{[t]}^r = \prod_i \alpha_{[t]}^i$ out of the first term gives $\mathbf{F}_{[t]}^r = \gamma_{[t]}^r\, \mathbf{P}_{[t]}^r$, where $\mathbf{P}_{[t]}^r$ is the $\beta$-only Householder product:

$$\mathbf{P}_{[t]}^r := \prod_{i=1}^r \bigl(\mathbf{I} - \beta_{[t]}^i \mathbf{k}_{[t]}^i \mathbf{k}_{[t]}^{i\top}\bigr), \qquad \mathbf{H}_{[t]}^r := \sum_{i=1}^r \beta_{[t]}^i \mathbf{v}_{[t]}^i \mathbf{k}_{[t]}^{i\top} \prod_{j=i+1}^r \bigl(\mathbf{I} - \beta_{[t]}^j \mathbf{k}_{[t]}^j \mathbf{k}_{[t]}^{j\top}\bigr)$$

For $\mathbf{G}_{[t]}^r$ the inner $\alpha_{[t]}^j$'s collapse to ratios $\gamma_{[t]}^r / \gamma_{[t]}^i$ that scale the $i$-th historical write — exactly the entries of the decay-aware mask $\boldsymbol{\Gamma}_{[t]}$ that re-enter at the matrix level below. We first derive the WY representation of the $\beta$-only building blocks $\mathbf{P}_{[t]}^r$ and $\mathbf{H}_{[t]}^r$ — these are DeltaNet's results — then re-insert the $\gamma$ factors to recover $\mathbf{F}_{[t]}^r$ and $\mathbf{G}_{[t]}^r$. Concretely, we want to prove:

$$\mathbf{P}_{[t]}^r = \mathbf{I} - \sum_{i=1}^r \mathbf{w}_{[t]}^i \mathbf{k}_{[t]}^{i\top}, \qquad \mathbf{w}_{[t]}^r = \beta_{[t]}^r\Bigl(\mathbf{k}_{[t]}^r - \sum_{i=1}^{r-1} \mathbf{w}_{[t]}^i (\mathbf{k}_{[t]}^{i\top}\mathbf{k}_{[t]}^r)\Bigr) \tag{2}$$

$$\mathbf{H}_{[t]}^r = \sum_{i=1}^r \mathbf{u}_{[t]}^i \mathbf{k}_{[t]}^{i\top}, \qquad \mathbf{u}_{[t]}^r = \beta_{[t]}^r\Bigl(\mathbf{v}_{[t]}^r - \sum_{i=1}^{r-1} \mathbf{u}_{[t]}^i (\mathbf{k}_{[t]}^{i\top}\mathbf{k}_{[t]}^r)\Bigr) \tag{3}$$

The rest of this section derives (2) and (3) from (1), then turns the per-step recursion into a single $C \times C$ triangular solve. Drop the $[t]$ subscript for readability.

Folding $\mathbf{P}^r$ into the W form

Base case ($r=1$). $\mathbf{P}^1 = \mathbf{I} - \beta^1 \mathbf{k}^1 \mathbf{k}^{1\top}$ already matches the WY form with $\mathbf{w}^1 = \beta^1 \mathbf{k}^1$, consistent with the $r=1$ instance of (2) (the inner sum is empty).

Inductive step. Assume $\mathbf{P}^{r-1} = \mathbf{I} - \sum_{i=1}^{r-1} \mathbf{w}^i \mathbf{k}^{i\top}$. Following the left-to-right product convention of (1):

$$\mathbf{P}^r = \mathbf{P}^{r-1}\bigl(\mathbf{I} - \beta^r \mathbf{k}^r \mathbf{k}^{r\top}\bigr) = \mathbf{P}^{r-1} - \beta^r \bigl(\mathbf{P}^{r-1} \mathbf{k}^r\bigr)\, \mathbf{k}^{r\top}$$

The only nontrivial piece is the vector $\mathbf{P}^{r-1} \mathbf{k}^r \in \mathbb{R}^{d_k}$:

$$\mathbf{P}^{r-1} \mathbf{k}^r = \mathbf{k}^r - \sum_{i=1}^{r-1} \mathbf{w}^i \bigl(\mathbf{k}^{i\top} \mathbf{k}^r\bigr)$$

which is exactly the bracketed expression inside the definition of $\mathbf{w}^r$ in (2). Defining

$$\mathbf{w}^r := \beta^r \Bigl(\mathbf{k}^r - \sum_{i=1}^{r-1} \mathbf{w}^i (\mathbf{k}^{i\top} \mathbf{k}^r)\Bigr)$$

reduces the previous line to $\mathbf{P}^r = \mathbf{P}^{r-1} - \mathbf{w}^r \mathbf{k}^{r\top}$, and substituting the inductive form for $\mathbf{P}^{r-1}$ yields

$$\mathbf{P}^r = \mathbf{I} - \sum_{i=1}^{r} \mathbf{w}^i \mathbf{k}^{i\top}$$

closing the induction.
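The closed induction can be checked numerically: the explicit left-to-right Householder product and the WY form built from the recursion in (2) should agree exactly. A numpy sketch (names are mine):

```python
import numpy as np

rng = np.random.default_rng(2)
r, dk = 5, 4
K = rng.normal(size=(r, dk))
beta = rng.uniform(0.1, 1.0, size=r)

# Explicit left-to-right product of the (I - beta k k^T) factors
P = np.eye(dk)
for i in range(r):
    P = P @ (np.eye(dk) - beta[i] * np.outer(K[i], K[i]))

# WY form via the w recursion in (2): w^r = beta^r (k^r - sum_i w^i (k^i . k^r))
W = np.zeros((r, dk))
for i in range(r):
    W[i] = beta[i] * (K[i] - W[:i].T @ (K[:i] @ K[i]))

assert np.allclose(P, np.eye(dk) - W.T @ K)
```

`W[:i].T @ (K[:i] @ K[i])` is the running sum $\sum_{j<i} \mathbf{w}^j (\mathbf{k}^{j\top}\mathbf{k}^i)$; for $i=0$ the empty matmul correctly yields zeros.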

Folding $\mathbf{H}^r$ into the U form

The accumulated-write term $\mathbf{H}^r$ admits its own one-step recurrence. Peel off the $i = r$ summand and factor the remaining product:

$$\mathbf{H}^r = \underbrace{\sum_{i=1}^{r-1} \beta^i \mathbf{v}^i \mathbf{k}^{i\top} \prod_{j=i+1}^{r-1}\bigl(\mathbf{I} - \beta^j \mathbf{k}^j \mathbf{k}^{j\top}\bigr)}_{=\, \mathbf{H}^{r-1}}\, \bigl(\mathbf{I} - \beta^r \mathbf{k}^r \mathbf{k}^{r\top}\bigr) + \beta^r \mathbf{v}^r \mathbf{k}^{r\top}$$

Expanding the right factor and rearranging:

$$\mathbf{H}^r = \mathbf{H}^{r-1} + \beta^r \bigl(\mathbf{v}^r - \mathbf{H}^{r-1} \mathbf{k}^r\bigr)\, \mathbf{k}^{r\top}$$

Inductive claim. $\mathbf{H}^r = \sum_{i=1}^r \mathbf{u}^i \mathbf{k}^{i\top}$ with the recursion from (3). For $r=1$, $\mathbf{u}^1 = \beta^1 \mathbf{v}^1$ (empty inner sum), so $\mathbf{u}^1 \mathbf{k}^{1\top} = \beta^1 \mathbf{v}^1 \mathbf{k}^{1\top} = \mathbf{H}^1$.

Assuming the claim at $r-1$, we have

$$\mathbf{H}^{r-1} \mathbf{k}^r = \sum_{i=1}^{r-1} \mathbf{u}^i \bigl(\mathbf{k}^{i\top} \mathbf{k}^r\bigr)$$

so the bracketed term in the one-step recurrence becomes

$$\beta^r\bigl(\mathbf{v}^r - \mathbf{H}^{r-1} \mathbf{k}^r\bigr) = \beta^r\Bigl(\mathbf{v}^r - \sum_{i=1}^{r-1} \mathbf{u}^i (\mathbf{k}^{i\top} \mathbf{k}^r)\Bigr) =: \mathbf{u}^r$$

which matches the definition of $\mathbf{u}^r$ in (3). Substituting back gives $\mathbf{H}^r = \mathbf{H}^{r-1} + \mathbf{u}^r \mathbf{k}^{r\top} = \sum_{i=1}^r \mathbf{u}^i \mathbf{k}^{i\top}$, closing the induction.

Intuition. $\mathbf{u}^r$ is the $r$-th write after the previously stored writes are projected away along the new key direction. Compared with the W recursion, the only change is the source vector: $\mathbf{k}^r$ for $\mathbf{w}^r$ (because we're tracking the transition matrix), $\mathbf{v}^r$ for $\mathbf{u}^r$ (because we're tracking the accumulated content).
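The same kind of numerical check works for the U form: evaluate $\mathbf{H}^r$ directly from its definition and compare against $\mathbf{U}^\top \mathbf{K}$ with $\mathbf{u}^r$ built by the recursion in (3). A numpy sketch (names are mine):

```python
import numpy as np

rng = np.random.default_rng(3)
r, dk, dv = 5, 4, 3
K = rng.normal(size=(r, dk)); V = rng.normal(size=(r, dv))
beta = rng.uniform(0.1, 1.0, size=r)

# Direct evaluation: H^r = sum_i beta_i v_i k_i^T prod_{j>i} (I - beta_j k_j k_j^T)
H = np.zeros((dv, dk))
for i in range(r):
    term = beta[i] * np.outer(V[i], K[i])
    for j in range(i + 1, r):
        term = term @ (np.eye(dk) - beta[j] * np.outer(K[j], K[j]))
    H += term

# U form via the recursion in (3): u^r = beta^r (v^r - sum_i u^i (k^i . k^r))
U = np.zeros((r, dv))
for i in range(r):
    U[i] = beta[i] * (V[i] - U[:i].T @ (K[:i] @ K[i]))

assert np.allclose(H, U.T @ K)
```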

From sequential recursion to a triangular solve: $\mathbf{W} = \mathbf{T}\mathbf{K}$

The recursion in (2) looks inherently sequential — $\mathbf{w}^r$ depends on every earlier $\mathbf{w}^i$. Stacking the per-step vectors row-wise into $\mathbf{W}, \mathbf{K} \in \mathbb{R}^{C \times d_k}$ (so $\mathbf{W}_{r,:} = \mathbf{w}^{r\top}$, $\mathbf{K}_{r,:} = \mathbf{k}^{r\top}$) reveals that it is actually a single triangular linear system in disguise.

Step 1 — transpose into row form. Transposing the $\mathbf{w}^r$ recursion in (2):

$$\mathbf{w}^{r\top} = \beta^r \mathbf{k}^{r\top} - \beta^r \sum_{i=1}^{r-1} \bigl(\mathbf{k}^{i\top}\mathbf{k}^r\bigr)\, \mathbf{w}^{i\top}$$

so row $r$ of $\mathbf{W}$ satisfies

$$\mathbf{W}_{r,:} = \beta^r \mathbf{K}_{r,:} - \beta^r \sum_{i=1}^{r-1} \bigl(\mathbf{k}^{i\top}\mathbf{k}^r\bigr)\, \mathbf{W}_{i,:} \tag{4}$$

Step 2 — identify coefficients with entries of $\mathrm{diag}(\beta)\, \mathbf{K}\mathbf{K}^\top$. Define $\mathbf{A} := \mathrm{diag}(\beta)\, \mathbf{K}\mathbf{K}^\top$. For any $r, i$:

$$\mathbf{A}_{r,i} = \beta^r\, (\mathbf{K}\mathbf{K}^\top)_{r,i} = \beta^r\, \mathbf{k}^{r\top}\mathbf{k}^i = \beta^r\, \mathbf{k}^{i\top}\mathbf{k}^r$$

(the last equality uses that $\mathbf{k}^{i\top}\mathbf{k}^r$ is a scalar). So the coefficient $\beta^r (\mathbf{k}^{i\top}\mathbf{k}^r)$ in (4) is exactly $\mathbf{A}_{r,i}$.

Step 3 — restrict to the strictly lower-triangular part. The sum in (4) runs only over $i = 1, \ldots, r-1$, never including $i \geq r$. Equivalently, only the strictly lower-triangular entries of $\mathbf{A}$ matter. Let $\mathbf{L} := \mathrm{strictLower}(\mathbf{A})$, with $\mathbf{L}_{r,i} = \mathbf{A}_{r,i}$ for $i < r$ and $0$ otherwise. Then

$$\sum_{i=1}^{r-1} \mathbf{A}_{r,i}\, \mathbf{W}_{i,:} = \sum_{i=1}^{C} \mathbf{L}_{r,i}\, \mathbf{W}_{i,:} = (\mathbf{L}\mathbf{W})_{r,:}$$

and (4) becomes

$$\mathbf{W}_{r,:} + (\mathbf{L}\mathbf{W})_{r,:} = \beta^r \mathbf{K}_{r,:} = (\mathrm{diag}(\beta)\, \mathbf{K})_{r,:}$$

Step 4 — stack rows. Across all $r = 1, \ldots, C$:

$$(\mathbf{I} + \mathbf{L})\, \mathbf{W} = \mathrm{diag}(\beta)\, \mathbf{K}$$

Step 5 — invert. $\mathbf{I} + \mathbf{L}$ is lower triangular with unit diagonal, hence invertible, giving the closed form

$$\mathbf{W} = (\mathbf{I} + \mathbf{L})^{-1}\, \mathrm{diag}(\beta)\, \mathbf{K} = \underbrace{\bigl[\mathbf{I} + \mathrm{strictLower}\bigl(\mathrm{diag}(\beta)\, \mathbf{K}\mathbf{K}^\top\bigr)\bigr]^{-1} \mathrm{diag}(\beta)}_{=\,\mathbf{T}}\, \mathbf{K}$$

That is, $\mathbf{W} = \mathbf{T}\,\mathbf{K}$.

The exact same argument applied to the $\mathbf{u}^r$ recursion in (3) replaces $\mathbf{K}$ on the right-hand side with $\mathbf{V}$ (because $\mathbf{u}^r$'s "source" vector is $\mathbf{v}^r$ rather than $\mathbf{k}^r$), giving $\mathbf{U} = \mathbf{T}\,\mathbf{V}$ with the same $\mathbf{T}$.

Cost view. The sequential reading computes $\mathbf{w}^1, \ldots, \mathbf{w}^C$ one at a time with $C$ data-dependent steps. The matrix reading $(\mathbf{I} + \mathbf{L})\mathbf{W} = \mathrm{diag}(\beta)\,\mathbf{K}$ is a single $C \times C$ lower-triangular system with $d_k$ right-hand sides (one per column of $\mathbf{K}$), solved by a batched forward substitution — tensor-core-friendly, and the same $\mathbf{T}$ is reused for both $\mathbf{W}$ and $\mathbf{U}$.
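The sequential and triangular-solve readings can be compared directly. A numpy sketch (names are mine; a dense `np.linalg.solve` stands in for a specialized batched triangular solver):

```python
import numpy as np

rng = np.random.default_rng(4)
C, dk, dv = 8, 4, 3
K = rng.normal(size=(C, dk)); V = rng.normal(size=(C, dv))
beta = rng.uniform(0.1, 1.0, size=C)

# Sequential reading of recursions (2) and (3)
W_seq = np.zeros((C, dk)); U_seq = np.zeros((C, dv))
for r in range(C):
    s = K[:r] @ K[r]                      # k_i . k_r for all i < r
    W_seq[r] = beta[r] * (K[r] - W_seq[:r].T @ s)
    U_seq[r] = beta[r] * (V[r] - U_seq[:r].T @ s)

# Matrix reading: solve (I + L) X = diag(beta) K  (resp. diag(beta) V)
A = beta[:, None] * (K @ K.T)             # diag(beta) K K^T
M = np.eye(C) + np.tril(A, k=-1)          # I + strictLower(A)
W_mat = np.linalg.solve(M, beta[:, None] * K)
U_mat = np.linalg.solve(M, beta[:, None] * V)

assert np.allclose(W_seq, W_mat) and np.allclose(U_seq, U_mat)
```

The `K` and `V` right-hand sides share the same $\mathbf{I} + \mathbf{L}$ factor, which is why $\mathbf{T}$ is computed once per chunk.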

Unified matrix view

Stacking the per-step vectors row-wise and taking $r = C$ (the chunk end), (2) and (3) compactly read

$$\mathbf{P}_{[t]}^C = \mathbf{I} - \mathbf{W}_{[t]}^\top \mathbf{K}_{[t]}, \qquad \mathbf{H}_{[t]}^C = \mathbf{U}_{[t]}^\top \mathbf{K}_{[t]}$$

with the UT transform

$$\mathbf{T}_{[t]} = \bigl[\mathbf{I} + \mathrm{strictLower}\bigl(\mathrm{diag}(\beta_{[t]})\, \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top\bigr)\bigr]^{-1} \mathrm{diag}(\beta_{[t]})$$

$$\mathbf{W}_{[t]} = \mathbf{T}_{[t]}\, \mathbf{K}_{[t]}, \qquad \mathbf{U}_{[t]} = \mathbf{T}_{[t]}\, \mathbf{V}_{[t]}$$

So a product of $r$ rank-1 perturbations plus an accumulated rank-$r$ write — together carrying $O(r(d_k + d_v))$ degrees of freedom — are expressed by a single batched $C \times C$ triangular inverse applied to $\mathbf{K}_{[t]}$ and $\mathbf{V}_{[t]}$, replacing $r$ sequential Householder-and-write applications with tensor-core-friendly matmuls.

Re-inserting the $\alpha$ gates

So far we have the WY representation of the $\beta$-only $\mathbf{P}_{[t]}^r$ and $\mathbf{H}_{[t]}^r$. Putting the $\alpha$'s back recovers $\mathbf{F}_{[t]}^r$ and $\mathbf{G}_{[t]}^r$ from (1). For $\mathbf{F}$ this is a global scalar: $\mathbf{F}_{[t]}^r = \gamma_{[t]}^r\, \mathbf{P}_{[t]}^r$, so the row-stacked form is $\mathbf{F}_{[t]}^r = \gamma_{[t]}^r(\mathbf{I} - \mathbf{W}_{[t]}^\top \mathbf{K}_{[t]})$ with the same $\mathbf{W}_{[t]}$.

For $\mathbf{G}$ the $\alpha$'s distribute as ratios: the $i$-th historical write enters with multiplier $\gamma_{[t]}^r / \gamma_{[t]}^i$, which is exactly the $(r, i)$ entry of the decay-aware causal mask $\boldsymbol{\Gamma}_{[t]}$. Tracing through the W-form argument, every $\beta^r (\mathbf{k}^{i\top}\mathbf{k}^r)$ coefficient picks up an extra $\gamma_{[t]}^r/\gamma_{[t]}^i$ factor, which means $\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top$ inside $\mathbf{T}$ is replaced by $\boldsymbol{\Gamma}_{[t]} \odot \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top$. The resulting gated UT transform gives the row-stacked accumulated-write matrix directly:

$$\widetilde{\mathbf{U}}_{[t]} = \Bigl[\mathbf{I} + \mathrm{strictLower}\bigl(\mathrm{diag}(\beta_{[t]})\,(\boldsymbol{\Gamma}_{[t]} \odot \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top)\bigr)\Bigr]^{-1} \mathrm{diag}(\beta_{[t]})\, \mathbf{V}_{[t]}$$

With $\overleftarrow{\cdot}$ denoting decay of each row to the first position of the chunk (row $r$ scaled by $\gamma_{[t]}^r$) and $\overrightarrow{\cdot}$ decay to the last position (row $r$ scaled by $\gamma_{[t]}^C / \gamma_{[t]}^r$; the whole state by $\gamma_{[t]}^C$), the cross-chunk recurrence and per-chunk output become

$$\mathbf{S}_{[t+1]} = \overrightarrow{\mathbf{S}_{[t]}} + \bigl(\widetilde{\mathbf{U}}_{[t]} - \overleftarrow{\mathbf{W}_{[t]}}\, \mathbf{S}_{[t]}^\top\bigr)^\top \overrightarrow{\mathbf{K}_{[t]}}$$

$$\mathbf{O}_{[t]} = \overleftarrow{\mathbf{Q}_{[t]}}\, \mathbf{S}_{[t]}^\top + \bigl(\mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \boldsymbol{\Gamma}_{[t]}\bigr)\bigl(\widetilde{\mathbf{U}}_{[t]} - \overleftarrow{\mathbf{W}_{[t]}}\, \mathbf{S}_{[t]}^\top\bigr)$$

where $\boldsymbol{\Gamma}_{[t]}$ is the decay-aware causal mask restricted to the chunk (the plain causal mask $\mathbf{M}$ would drop the $\gamma_{[t]}^r/\gamma_{[t]}^i$ ratios that the intra-chunk writes still carry), and $\mathbf{W}_{[t]}$ is the $\beta$-only UT transform applied to $\mathbf{K}_{[t]}$, so $\overleftarrow{\mathbf{W}_{[t]}} = \mathrm{diag}(\gamma_{[t]})\, \mathbf{T}_{[t]}\, \mathbf{K}_{[t]}$ (equivalently, the gated UT transform applied to $\mathrm{diag}(\gamma_{[t]})\, \mathbf{K}_{[t]}$). Every intra-chunk operation is now matmul-shaped, so the algorithm preserves the gated delta semantics exactly while running tensor-core-bound; only the inter-chunk state propagation remains sequential.
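The whole pipeline can be checked end to end with a numpy sketch (all names are mine) that runs the token-by-token recurrence against the chunkwise algorithm. The intra-chunk term carries the decay-aware mask $\boldsymbol{\Gamma}_{[t]}$, and $\overleftarrow{\mathbf{W}}$ is computed here as the gated UT transform applied to the $\gamma$-scaled keys:

```python
import numpy as np

rng = np.random.default_rng(5)
L, C, dk, dv = 16, 4, 4, 3
Q = rng.normal(size=(L, dk)); K = rng.normal(size=(L, dk))
K /= np.linalg.norm(K, axis=1, keepdims=True)   # unit keys, as in GDN
V = rng.normal(size=(L, dv))
alpha = rng.uniform(0.8, 1.0, size=L); beta = rng.uniform(0.1, 0.9, size=L)

# Reference: token-by-token gated delta rule
S = np.zeros((dv, dk)); O_ref = np.zeros((L, dv))
for t in range(L):
    S = S @ (alpha[t] * (np.eye(dk) - beta[t] * np.outer(K[t], K[t]))) \
        + beta[t] * np.outer(V[t], K[t])
    O_ref[t] = S @ Q[t]

# Chunkwise form
S = np.zeros((dv, dk)); O_chk = np.zeros((L, dv))
for c in range(0, L, C):
    q, k, v = Q[c:c+C], K[c:c+C], V[c:c+C]
    b = beta[c:c+C]
    g = np.cumprod(alpha[c:c+C])                 # in-chunk cumulative decay gamma_r
    Gam = np.tril(g[:, None] / g[None, :])       # decay-aware mask, restricted to chunk
    # gated UT transform (dense solve stands in for a triangular solver)
    M = np.eye(C) + np.tril(b[:, None] * (Gam * (k @ k.T)), k=-1)
    U = np.linalg.solve(M, b[:, None] * v)       # U-tilde
    Wl = np.linalg.solve(M, (b * g)[:, None] * k)  # <-W = diag(gamma) T K
    D = U - Wl @ S.T                             # (U~ - <-W S^T), one row per token
    O_chk[c:c+C] = (g[:, None] * q) @ S.T + (Gam * (q @ k.T)) @ D
    S = g[-1] * S + D.T @ ((g[-1] / g)[:, None] * k)   # cross-chunk recurrence

assert np.allclose(O_ref, O_chk)
```

With the plain causal mask in place of `Gam` in the output line, the assertion fails: the intra-chunk reads would no longer see the surviving decay fractions between write and read positions.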

Block compute flow

The block diagram realizes one step of this recurrence. Let $\mathbf{x}_t$ be the block input.

  1. Four linear projections fan $\mathbf{x}_t$ out into $(\mathbf{q}, \mathbf{k})$, $\mathbf{v}$, $(\alpha, \beta)$, and a residual gate $\mathbf{g}$.
  2. The $\mathbf{q}, \mathbf{k}$ path applies a short causal Conv, then SiLU, then L2 normalization on each head. The L2 step pins $\|\mathbf{k}_t\| = 1$ so that $(\mathbf{I} - \beta_t \mathbf{k}_t \mathbf{k}_t^\top)$ stays well-conditioned after many delta updates.
  3. The $\mathbf{v}$ path uses the same Conv-then-SiLU stack without L2, since values are the content being written rather than the lookup direction.
  4. The $\alpha, \beta$ path is a plain linear projection; $\alpha_t$ uses Mamba2's parameterization and $\beta_t$ a sigmoid so both stay in $(0, 1)$.
  5. $(\mathbf{q}_t, \mathbf{k}_t, \mathbf{v}_t, \alpha_t, \beta_t)$ feed the gated delta rule and yield $\mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t$.
  6. $\mathbf{o}_t$ is RMS-normalized, gated elementwise by SiLU($\mathbf{g}$) (the residual branch on the right of the diagram), then projected back to the model dimension by the top Linear.
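Putting the steps together, a heavily simplified single-head forward pass might look like the numpy sketch below. The short causal conv is omitted, the $\alpha$ parameterization is only Mamba2-flavored ($\exp(-\mathrm{softplus}(\cdot)\,A)$), and all weight names (`Wq`, `Wk`, ...) are hypothetical, not taken from any real implementation:

```python
import numpy as np

def silu(x): return x / (1.0 + np.exp(-x))
def softplus(x): return np.log1p(np.exp(x))
def l2norm(x): return x / np.linalg.norm(x, axis=-1, keepdims=True)
def rmsnorm(x): return x / np.sqrt((x * x).mean(-1, keepdims=True) + 1e-6)

def gdn_block(x, p):
    """Single-head GDN block sketch. x: (L, d_model); p: dict of hypothetical
    projection matrices. The short causal conv before q/k/v is omitted."""
    q = l2norm(silu(x @ p["Wq"]))           # lookup directions, unit norm
    k = l2norm(silu(x @ p["Wk"]))           # unit keys keep I - beta k k^T stable
    v = silu(x @ p["Wv"])                   # content path: no L2
    alpha = np.exp(-softplus(x @ p["Wa"]).squeeze(-1) * p["A"])   # gate in (0, 1)
    beta = 1.0 / (1.0 + np.exp(-(x @ p["Wb"]).squeeze(-1)))       # sigmoid in (0, 1)
    g = x @ p["Wg"]                         # residual gate branch

    dk, dv = k.shape[1], v.shape[1]
    S = np.zeros((dv, dk)); o = np.zeros((x.shape[0], dv))
    for t in range(x.shape[0]):             # gated delta rule, recurrent reading
        S = S @ (alpha[t] * (np.eye(dk) - beta[t] * np.outer(k[t], k[t]))) \
            + beta[t] * np.outer(v[t], k[t])
        o[t] = S @ q[t]
    return (rmsnorm(o) * silu(g)) @ p["Wo"]  # norm, gate, project back

rng = np.random.default_rng(0)
d, dk, dv = 16, 8, 8
p = {"Wq": rng.normal(size=(d, dk)), "Wk": rng.normal(size=(d, dk)),
     "Wv": rng.normal(size=(d, dv)), "Wa": rng.normal(size=(d, 1)),
     "Wb": rng.normal(size=(d, 1)), "Wg": rng.normal(size=(d, dv)),
     "Wo": rng.normal(size=(dv, d)), "A": 0.5}
y = gdn_block(rng.normal(size=(10, d)), p)
assert y.shape == (10, d)
```

A real implementation would use the chunkwise algorithm above instead of the per-token loop; the loop is kept here only because it mirrors the block diagram step by step.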

Cost summary vs full SDPA

The chunkwise algorithm replaces SDPA's $O(L^2)$ sequence-axis cost with a fixed-size recurrent state plus per-chunk local work. Let $L$ be the sequence length, $d$ the per-head dimension (taking $d_k = d_v = d$ for brevity), and $C$ the chunk size ($C = 64$ in fla).

| Quantity (per head, per layer) | Full SDPA | GDN chunkwise |
| --- | --- | --- |
| Training compute | $O(L^2 d)$ | $O(LCd + Ld^2)$ |
| Training memory (activations) | $O(Ld)$ with FlashAttention | $O(Ld) + O(d^2)$ state |
| Inference compute per generated token | $O(Ld)$ (scan KV cache) | $O(d^2)$ (matvec with state) |
| Inference memory per generated token | $O(Ld)$, KV cache grows | $O(d^2)$, state is fixed |

Training. SDPA scales quadratically with $L$ along the sequence axis. GDN turns the quadratic term into a chunk-local $O(LCd)$ contribution (linear in $L$ once $C$ is fixed) plus a cross-chunk $O(Ld^2)$ term from sweeping the $d \times d$ state through the $L/C$ chunks. For $L \gg \max(C, d)$ the saving in attention-core flops is roughly a factor $L / (C + d)$.

Inference. SDPA must attend to every prior key, so per-token compute and KV-cache memory both grow linearly with $L$. GDN compresses everything into the $d_v \times d_k$ matrix $\mathbf{S}_t$: each new token costs one $O(d^2)$ matvec, and the memory footprint never grows with context length.
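Plugging concrete numbers into these asymptotics makes the gap tangible (pure arithmetic, constants ignored; the values of $L$, $d$, $C$ below are illustrative choices, not from any specific model):

```python
# Rough attention-core flop comparison per head, constants ignored.
L, d, C = 32768, 128, 64

sdpa_train = L * L * d                 # O(L^2 d)
gdn_train = L * C * d + L * d * d      # chunk-local + cross-chunk state sweep
train_ratio = sdpa_train / gdn_train   # = L / (C + d)
print(train_ratio)                     # 32768 / 192 ≈ 170.7x fewer flops

sdpa_decode = L * d                    # scan the whole KV cache per new token
gdn_decode = d * d                     # one matvec with the fixed-size state
print(sdpa_decode / gdn_decode)        # = L / d = 256.0
```

The training ratio simplifies exactly to $L/(C+d)$, which is why long contexts are where the chunkwise algorithm pays off most.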

Footnotes

  1. Gated Delta Networks: Improving Mamba2 with Delta Rule, https://arxiv.org/abs/2412.06464v3

  2. Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality, https://arxiv.org/abs/2405.21060

  3. Parallelizing Linear Transformers with the Delta Rule over Sequence Length, https://arxiv.org/abs/2406.06484