I've tried to learn online softmax (the core mechanism of FlashAttention) many times. This time I'm determined to work through every single detail, so that this will be the last time I have to learn it.
The algorithm
The safe softmax for a row $x \in \mathbb{R}^D$ is

$$y_i = \frac{e^{x_i - m}}{\sum_{j=1}^{D} e^{x_j - m}}, \qquad m = \max_j x_j.$$

The naive implementation takes 3 passes over $x$: one for the maximum $m$, one for the denominator $s = \sum_i e^{x_i - m}$, and one to write the normalized output.
Online softmax fuses the first two passes by maintaining a running maximum $m_i$ and a running denominator $s_i$, where the update rule is:

$$m_i = \max(m_{i-1},\, x_i), \qquad s_i = s_{i-1}\, e^{m_{i-1} - m_i} + e^{x_i - m_i},$$

starting from $m_0 = -\infty$ and $s_0 = 0$. A second pass then writes $y_i = e^{x_i - m_D} / s_D$.
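The recurrence is short enough to transcribe directly into plain Python with NumPy. This is an illustrative sketch (the function name and loop structure are mine, not from any library):

```python
import numpy as np

def online_softmax_ref(x: np.ndarray) -> np.ndarray:
    """Two-pass online softmax over one row (plain-Python sketch)."""
    m, s = -np.inf, 0.0  # running maximum, running denominator
    # Pass 1: fused maximum + denominator via the online update rule.
    for x_i in x:
        m_new = max(m, x_i)
        s = s * np.exp(m - m_new) + np.exp(x_i - m_new)
        m = m_new
    # Pass 2: normalize.
    return np.exp(x - m) / s
```

Note that the first iteration is handled by the initial conditions: with $m = -\infty$ and $s = 0$, the rescaling factor $e^{m - m_1}$ is exactly zero.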
The recurrence looks sequential: element $i$ seems to wait for element $i-1$. But the reducer

$$(m_a, s_a) \oplus (m_b, s_b) = \left(m_c,\; s_a\, e^{m_a - m_c} + s_b\, e^{m_b - m_c}\right), \qquad m_c = \max(m_a, m_b)$$

is associative (and commutative). That means we can reduce all the elements along the dimension in any order, or in parallel, and still obtain the correct result.
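Associativity is easy to check numerically. The sketch below (names are mine) seeds each element as the trivial pair $(m, s) = (x_i, 1)$, then reduces the same row strictly left-to-right and in pairwise tree order; both groupings produce the same $(m, s)$:

```python
import numpy as np
from functools import reduce

def combine(a, b):
    """Associative reducer over (max, denominator) pairs."""
    (m_a, s_a), (m_b, s_b) = a, b
    m_c = max(m_a, m_b)
    return m_c, s_a * np.exp(m_a - m_c) + s_b * np.exp(m_b - m_c)

x = np.random.default_rng(0).standard_normal(64)
pairs = [(x_i, 1.0) for x_i in x]  # each element starts as (m=x_i, s=1)

left = reduce(combine, pairs)  # strict left-to-right reduction

# Tree order: combine adjacent pairs until one remains (len is a power of 2).
tree = pairs
while len(tree) > 1:
    tree = [combine(tree[i], tree[i + 1]) for i in range(0, len(tree), 2)]
tree = tree[0]
```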
FlashAttention 1 extends this idea one step further by fusing the normalization into the output accumulator of the attention matmul, reducing the whole procedure to a single pass.
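The fused-accumulator idea can be sketched in NumPy, simplified to one softmax row weighting a matrix of values (function and variable names are mine; this is the single-pass trick, not FlashAttention's tiled implementation):

```python
import numpy as np

def one_pass_softmax_weighted_sum(x: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Single pass computing softmax(x) @ v without a separate softmax pass."""
    m, s = -np.inf, 0.0
    o = np.zeros_like(v[0])  # unnormalized output accumulator
    for x_i, v_i in zip(x, v):
        m_new = max(m, x_i)
        scale = np.exp(m - m_new)  # rescales both s and o to the new max
        p = np.exp(x_i - m_new)
        s = s * scale + p
        o = o * scale + p * v_i    # accumulate unnormalized weighted values
        m = m_new
    return o / s  # normalize once, at the very end
```

The key point is that the output accumulator `o` is rescaled by the same factor as the denominator `s` whenever the running maximum changes, so a single division at the end suffices.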
CuTe DSL implementation
We now port the same algorithm to a single CUDA kernel using the CuTe Python DSL (CUTLASS 4.x). The $(N, D)$ shape maps naturally to a 1D grid of $N$ CTAs, one per row, with each CTA cooperatively reducing its row across `THREADS` threads so that every thread owns $V = D\,/\,\mathrm{THREADS}$ elements in registers.
The kernel follows the same 3-stage structure:
- Each thread streams its $V$ values and maintains a running $(m, s)$ pair with the online update rule.
- Threads in a CTA exchange their $(m, s)$ pairs through shared memory using the same reducer, collapsing to a single row-wise $(m_{\mathrm{row}}, s_{\mathrm{row}})$.
- A second pass over the register-resident values writes the normalized output.
There is no need to write intermediate results back to gmem.
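Before reading the kernel, the same three stages can be simulated on the CPU. The sketch below (thread count and names are mine) mirrors the kernel's strided element ownership and the tree reduction over per-thread partials:

```python
import numpy as np

def simulate_cta_softmax(x: np.ndarray, threads: int = 8) -> np.ndarray:
    """CPU simulation of the kernel's three stages for one row."""
    assert x.size % threads == 0
    v = x.size // threads
    # Stage 1: "thread" t owns elements t, t + threads, t + 2*threads, ...
    # and folds them into a per-thread (m, s) pair online.
    m = np.full(threads, -np.inf)
    s = np.zeros(threads)
    for t in range(threads):
        for k in range(v):
            x_i = x[t + k * threads]
            m_new = max(m[t], x_i)
            s[t] = s[t] * np.exp(m[t] - m_new) + np.exp(x_i - m_new)
            m[t] = m_new
    # Stage 2: tree reduction, halving the active range each step.
    offset = threads // 2
    while offset > 0:
        for t in range(offset):
            m_c = max(m[t], m[t + offset])
            s[t] = (s[t] * np.exp(m[t] - m_c)
                    + s[t + offset] * np.exp(m[t + offset] - m_c))
            m[t] = m_c
        offset //= 2
    # Stage 3: normalize with the row-wide (m, s) left in slot 0.
    return np.exp(x - m[0]) / s[0]
```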
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import torch

THREADS = 512

@cute.struct
class SharedStorage:
    smem_m: cute.struct.MemRange[cutlass.Float32, THREADS]
    smem_s: cute.struct.MemRange[cutlass.Float32, THREADS]
@cute.kernel
def online_softmax_kernel(
    mX: cute.Tensor,  # (N, D), row-major
    mY: cute.Tensor,  # (N, D), row-major
    smem_layout: cute.Layout,  # (THREADS,) shared-memory layout
    THREADS: cutlass.Constexpr[int],
    V: cutlass.Constexpr[int],
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    gX_row = mX[bidx, None]  # (D,)
    gY_row = mY[bidx, None]  # (D,)

    # Per-thread register fragment of V elements, strided by THREADS.
    rX = cute.make_rmem_tensor(cute.make_layout(V), cutlass.Float32)
    for v in cutlass.range(V):
        rX[v] = gX_row[tidx + v * THREADS]

    # ---- Stage 1: per-thread online (m, s) ----
    m = -cutlass.Float32.inf
    s = cutlass.Float32(0.0)
    for v in cutlass.range(V):
        x_v = rX[v]
        m_new = cute.arch.fmax(m, x_v)
        s = s * cute.math.exp(m - m_new, fastmath=True) + cute.math.exp(
            x_v - m_new, fastmath=True
        )
        m = m_new

    # ---- Stage 2: block-wide combine via shared memory ----
    smem_alloc = cutlass.utils.SmemAllocator()
    storage = smem_alloc.allocate(SharedStorage)
    smem_m = storage.smem_m.get_tensor(smem_layout)
    smem_s = storage.smem_s.get_tensor(smem_layout)

    smem_m[tidx] = m
    smem_s[tidx] = s
    cute.arch.barrier()

    # Tree reduction with the same associative reducer.
    offset = THREADS // 2
    while offset > 0:
        if tidx < offset:
            m_a, m_b = smem_m[tidx], smem_m[tidx + offset]
            s_a, s_b = smem_s[tidx], smem_s[tidx + offset]
            m_c = cute.arch.fmax(m_a, m_b)
            s_c = s_a * cute.math.exp(m_a - m_c, fastmath=True) + s_b * cute.math.exp(
                m_b - m_c, fastmath=True
            )
            smem_m[tidx] = m_c
            smem_s[tidx] = s_c
        cute.arch.barrier()
        offset //= 2

    m_row = smem_m[0]
    s_row = smem_s[0]
    inv_s = cutlass.Float32(1.0) / s_row

    # ---- Stage 3: normalize and write back ----
    for v in cutlass.range(V):
        gY_row[tidx + v * THREADS] = cute.math.exp(rX[v] - m_row, fastmath=True) * inv_s
@cute.jit
def online_softmax(mX: cute.Tensor, mY: cute.Tensor):
    N, D = mX.shape
    V = D // THREADS  # assumes D is a multiple of THREADS
    smem_layout = cute.make_layout(THREADS)
    online_softmax_kernel(mX, mY, smem_layout, THREADS, V).launch(
        grid=(N, 1, 1),
        block=(THREADS, 1, 1),
    )