The last time I learn softmax

April 22, 2026

I've tried to learn online softmax (the core mechanism of FlashAttention) many times. This time, I'm serious enough to figure out every single detail, so that this will truly be the last time I learn online softmax.

The algorithm

The safe softmax for a row $\mathbf{x} \in \mathbb{R}^{D}$ is

$$\mathrm{softmax}(\mathbf{x})_j = \frac{\exp(x_j - m)}{\sum_{k=1}^{D} \exp(x_k - m)}, \quad m = \max_{k} x_k.$$

The naive implementation takes 3 passes over $\mathbf{x}$: one for $m$, one for the denominator $s$, and one to write the normalized output.
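
To make the baseline concrete, here is a minimal pure-Python sketch of the three passes (the function name is mine, for illustration):

```python
import math

def safe_softmax_3pass(x):
    # Pass 1: row maximum (subtracting it keeps exp() from overflowing).
    m = max(x)
    # Pass 2: the denominator.
    s = sum(math.exp(v - m) for v in x)
    # Pass 3: write the normalized output.
    return [math.exp(v - m) / s for v in x]
```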

Online softmax fuses the first two passes by maintaining a running maximum $m^{(j)}$ and a running denominator $s^{(j)}$, where the update rule is:

$$m^{(j)} = \max(m^{(j-1)}, x_j), \qquad s^{(j)} = s^{(j-1)} \cdot \exp(m^{(j-1)} - m^{(j)}) + \exp(x_j - m^{(j)}).$$

A second pass then writes $\exp(x_j - m) / s$.
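
As a sketch in pure Python (illustrative names, a scalar loop instead of a kernel):

```python
import math

def online_softmax_2pass(x):
    # Pass 1: fused running max m and running denominator s.
    m, s = float("-inf"), 0.0
    for v in x:
        m_new = max(m, v)
        # Rescale the old partial sum to the new max, then add the new term.
        s = s * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    # Pass 2: normalize against the final (m, s).
    return [math.exp(v - m) / s for v in x]
```

Note that `math.exp(m - m_new)` is `exp(-inf) == 0.0` on the first iteration, so no special-casing is needed.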

🤯

The recurrence looks sequential: element $j$ seems to wait for element $j-1$. But the reducer

$$(m_a, s_a) \oplus (m_b, s_b) = \big(\max(m_a, m_b),\ s_a \cdot e^{m_a - m_c} + s_b \cdot e^{m_b - m_c}\big), \quad m_c = \max(m_a, m_b)$$

is associative (and commutative): $(p_i \oplus p_j) \oplus p_k = p_i \oplus (p_j \oplus p_k)$ for pairs $p = (m, s)$. That means we can reduce the elements along the $D$ dimension in any order, sequentially, in a tree, or split across threads, and still obtain the correct result.
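
One can check this numerically: reduce singleton pairs $(x_j, 1)$ left-to-right and as a tree, and compare. A sketch (the function name is mine):

```python
import math
from functools import reduce

def combine(a, b):
    # The (m, s) reducer: rescale both partial denominators to the shared max.
    m_a, s_a = a
    m_b, s_b = b
    m_c = max(m_a, m_b)
    return (m_c, s_a * math.exp(m_a - m_c) + s_b * math.exp(m_b - m_c))

x = [0.5, 2.0, -1.0, 3.0, 1.5, -0.5]
pairs = [(v, 1.0) for v in x]       # a single element reduces to (x_j, 1)
sequential = reduce(combine, pairs)  # left-to-right
tree = combine(reduce(combine, pairs[:3]),
               reduce(combine, pairs[3:]))  # two halves, then merge
```

Both orders give the same $(m, s)$ up to floating-point rounding, which is exactly what lets threads reduce their chunks independently.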

FlashAttention 1 extends this idea one step further by fusing the normalization into the output accumulator of the attention matmul, reducing the whole procedure to a single pass.
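
On a toy scalar version the fusion looks like this: treat each score $x_j$ as weighting a scalar value $v_j$, keep the output accumulator unnormalized, and rescale it by $\exp(m^{(j-1)} - m^{(j)})$ whenever the running max changes. This is only an illustrative sketch (names are mine); FlashAttention does the same with tiles and matmuls:

```python
import math

def one_pass_weighted_sum(scores, values):
    # Computes sum_j softmax(scores)_j * values[j] in a single pass.
    m = float("-inf")  # running max
    s = 0.0            # running denominator
    o = 0.0            # running unnormalized output
    for x, v in zip(scores, values):
        m_new = max(m, x)
        scale = math.exp(m - m_new)  # rescales everything accumulated so far
        p = math.exp(x - m_new)
        s = s * scale + p
        o = o * scale + p * v
        m = m_new
    return o / s  # normalize once, at the very end
```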

CuTe DSL implementation

We now port the same algorithm to a single CUDA kernel using the CuTe Python DSL (CUTLASS 4.x). The shape $(N, D) = (128, 16384)$ maps naturally to a 1D grid of $N$ CTAs, one per row, with each CTA cooperatively reducing its row across $T = 512$ threads so that every thread owns $V = D / T = 32$ elements in registers.

The kernel follows the same three-stage structure:

  1. Each thread streams its $V$ values and maintains a running $(m, s)$ pair with the online update rule.
  2. Threads in a CTA exchange their $(m, s)$ pairs through shared memory using the same reducer, collapsing to a single row-wise $(m, s)$.
  3. A second pass over the $V$ register-resident values writes the normalized output.

There is no need to write intermediate results back to gmem.

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import torch

THREADS = 512


@cute.struct
class SharedStorage:
    smem_m: cute.struct.MemRange[cutlass.Float32, THREADS]
    smem_s: cute.struct.MemRange[cutlass.Float32, THREADS]


@cute.kernel
def online_softmax_kernel(
    mX: cute.Tensor,  # (N, D), row-major
    mY: cute.Tensor,  # (N, D), row-major
    smem_layout: cute.Layout,  # (THREADS,) shared-memory layout
    THREADS: cutlass.Constexpr[int],
    V: cutlass.Constexpr[int],
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    gX_row = mX[bidx, None]  # (D,)
    gY_row = mY[bidx, None]  # (D,)

    # Per-thread register fragment of V elements, strided by THREADS.
    rX = cute.make_rmem_tensor(cute.make_layout(V), cutlass.Float32)
    for v in cutlass.range(V):
        rX[v] = gX_row[tidx + v * THREADS]

    # ---- Stage 1: per-thread online (m, s) ----
    m = -cutlass.Float32.inf
    s = cutlass.Float32(0.0)
    for v in cutlass.range(V):
        x_v = rX[v]
        m_new = cute.arch.fmax(m, x_v)
        s = s * cute.math.exp(m - m_new, fastmath=True) + cute.math.exp(
            x_v - m_new, fastmath=True
        )
        m = m_new

    # ---- Stage 2: block-wide combine via shared memory ----
    smem_alloc = cutlass.utils.SmemAllocator()
    storage = smem_alloc.allocate(SharedStorage)
    smem_m = storage.smem_m.get_tensor(smem_layout)
    smem_s = storage.smem_s.get_tensor(smem_layout)
    smem_m[tidx] = m
    smem_s[tidx] = s
    cute.arch.barrier()

    offset = THREADS // 2
    while offset > 0:
        if tidx < offset:
            m_a, m_b = smem_m[tidx], smem_m[tidx + offset]
            s_a, s_b = smem_s[tidx], smem_s[tidx + offset]
            m_c = cute.arch.fmax(m_a, m_b)
            s_c = s_a * cute.math.exp(m_a - m_c, fastmath=True) + s_b * cute.math.exp(
                m_b - m_c, fastmath=True
            )
            smem_m[tidx] = m_c
            smem_s[tidx] = s_c
        cute.arch.barrier()
        offset //= 2

    m_row = smem_m[0]
    s_row = smem_s[0]
    inv_s = cutlass.Float32(1.0) / s_row

    # ---- Stage 3: normalize and write back ----
    for v in cutlass.range(V):
        gY_row[tidx + v * THREADS] = cute.math.exp(rX[v] - m_row, fastmath=True) * inv_s


@cute.jit
def online_softmax(mX: cute.Tensor, mY: cute.Tensor):
    N, D = mX.shape
    V = D // THREADS
    smem_layout = cute.make_layout(THREADS)

    online_softmax_kernel(mX, mY, smem_layout, THREADS, V).launch(
        grid=(N, 1, 1),
        block=(THREADS, 1, 1),
    )
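
For completeness, a host-side driver might look like the sketch below. This assumes a CUDA-capable GPU with the CUTLASS 4.x Python package installed; `from_dlpack` (imported above) wraps the torch tensors for the JIT-compiled entry point, and the result is checked against `torch.softmax`:

```python
x = torch.randn(128, 16384, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)
online_softmax(from_dlpack(x), from_dlpack(y))
torch.testing.assert_close(y, torch.softmax(x, dim=-1), atol=1e-5, rtol=1e-5)
```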

Footnotes

  1. Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." https://arxiv.org/abs/2205.14135