Context parallel (CP) is used to relieve memory pressure when a sequence is longer than a single GPU can handle.
In this post, I'm going to walk through the CP variants: p2p, a2a, hybrid CP, as well as dynamic CP.
Ring attention (p2p CP)
For a sequence, each CP rank holds $Q$, $K$ and $V$ for a slice of the tokens, assigned in zig-zag order for load balance. During forward, each rank sends its $K$, $V$ to the next neighbor while receiving from the previous neighbor, in ring order. Meanwhile (with computation-communication overlap), each rank also calculates the partial attention output using online softmax.
Q: Why transmit $K$, $V$ rather than $Q$?
- The partial output $O$ is stored locally and then fed into online softmax for the next step
- If we kept $K$, $V$ on each rank and let $Q$ rotate in-flight, the intermediate results would have to be transferred together with $Q$
- Attention variants like GQA and MQA use less memory for $K$, $V$ than for $Q$
And this is why we require .
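To make the load-balance argument concrete, here is a minimal sketch (the helper names are mine, not Megatron's or TE's). It assumes the usual zig-zag scheme in which the sequence is cut into `2 * cp_size` chunks and rank `r` owns chunks `r` and `2 * cp_size - 1 - r`, then counts how many unmasked (query, key) pairs each rank computes under a causal mask.

```python
def zigzag_chunks(cp_size, rank):
    """Chunk ids (out of 2 * cp_size) owned by `rank` under zig-zag assignment."""
    return [rank, 2 * cp_size - 1 - rank]


def causal_work(chunk_ids, chunk_len):
    """Count (query, key) pairs with key <= query for queries in the given chunks."""
    work = 0
    for c in chunk_ids:
        for q in range(c * chunk_len, (c + 1) * chunk_len):
            work += q + 1  # a causal query at position q attends to keys 0..q
    return work


cp_size, chunk_len = 4, 2
zigzag = [causal_work(zigzag_chunks(cp_size, r), chunk_len) for r in range(cp_size)]
contiguous = [causal_work([2 * r, 2 * r + 1], chunk_len) for r in range(cp_size)]
print("zig-zag   :", zigzag)
print("contiguous:", contiguous)
```

With a contiguous split the last rank does several times the work of the first one, while the zig-zag split gives every rank the same amount.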
CP attention dispatch in Megatron
The actual ring attention and online softmax happen inside TE.
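We now know that online softmax can compute the masked scores in one pass. But ring attention has an extra step: multiplying the scores by $V$. So how does CP merge partial results in which $V$ is already folded in? For one query row with scores $s_j$ against all keys, we have:

$$O = \sum_j \frac{e^{s_j}}{\sum_k e^{s_k}} v_j = \sum_j e^{s_j - \mathrm{LSE}} v_j, \qquad \mathrm{LSE} = \log \sum_j e^{s_j}$$

Similarly, for the key/value chunk $C_t$ seen at ring step $t$, we have $O_t = \sum_{j \in C_t} e^{s_j - \mathrm{LSE}_t} v_j$ with $\mathrm{LSE}_t = \log \sum_{j \in C_t} e^{s_j}$. As CP splits the computation along the sequence dimension, it's actually a special (weighted) reduction in row-wise matrix multiplication parallelism:

$$O = \sum_t w_t O_t$$

We can maintain $w_t$ via LSE, using $\mathrm{LSE} = \log \sum_t e^{\mathrm{LSE}_t}$:

$$w_t = e^{\mathrm{LSE}_t - \mathrm{LSE}}$$

thus $O = \sum_t e^{\mathrm{LSE}_t - \mathrm{LSE}} O_t$. Note that this update is associative, therefore we just need to maintain the accumulators $O^{(t)}$ and $\mathrm{LSE}^{(t)}$ at each ring step $t$. Let $\mathrm{LSE}_t$ and $O_t$ be the partial LSE and output produced at step $t$; the running update is:

$$m = \max\left(\mathrm{LSE}^{(t-1)}, \mathrm{LSE}_t\right)$$

$$\mathrm{LSE}^{(t)} = m + \log\left(e^{\mathrm{LSE}^{(t-1)} - m} + e^{\mathrm{LSE}_t - m}\right)$$

$$O^{(t)} = e^{\mathrm{LSE}^{(t-1)} - \mathrm{LSE}^{(t)}} O^{(t-1)} + e^{\mathrm{LSE}_t - \mathrm{LSE}^{(t)}} O_t$$

where $m$ subtracts the maximum for numerical stability, and $e^{\mathrm{LSE}^{(t-1)} - \mathrm{LSE}^{(t)}}$ with $e^{\mathrm{LSE}_t - \mathrm{LSE}^{(t)}}$ act as the soft-mixing weights on the running accumulator and the new partial.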
TE does not implement this recurrence on the output $O$ literally. Only the LSE accumulator is streamed across ring steps; each step's partial output $O_t$ and $\mathrm{LSE}_t$ are stashed in a per-step list. After the ring loop finishes, a single post-loop pass folds them in via $O = \sum_t e^{\mathrm{LSE}_t - \mathrm{LSE}} O_t$, which is mathematically equivalent to the streaming form once the final $\mathrm{LSE}$ is known (the running rescaling collapses). This may lead to larger runtime memory consumption.
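The equivalence between the streaming update and the post-loop fold can be checked numerically. The sketch below is plain Python for a single query row, not TE code; the function names and the toy scores/values are made up for illustration. It computes per-chunk partial outputs with their LSE, merges them both ways, and compares against attention over the full key set.

```python
import math


def partial_attention(scores, values):
    """Softmax-weighted sum over one KV chunk; returns (output, lse)."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    dim = len(values[0])
    out = [sum(math.exp(s - lse) * v[d] for s, v in zip(scores, values))
           for d in range(dim)]
    return out, lse


def merge_lse(lse_a, lse_b):
    """Numerically stable log(exp(lse_a) + exp(lse_b))."""
    m = max(lse_a, lse_b)
    return m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))


def merge_streaming(partials):
    """Rescale the running output at every ring step."""
    out, lse = partials[0]
    for out_t, lse_t in partials[1:]:
        new_lse = merge_lse(lse, lse_t)
        out = [math.exp(lse - new_lse) * a + math.exp(lse_t - new_lse) * b
               for a, b in zip(out, out_t)]
        lse = new_lse
    return out, lse


def merge_post_loop(partials):
    """Stream only the LSE, then fold all stashed partial outputs once."""
    lse = partials[0][1]
    for _, lse_t in partials[1:]:
        lse = merge_lse(lse, lse_t)
    dim = len(partials[0][0])
    out = [sum(math.exp(lse_t - lse) * out_t[d] for out_t, lse_t in partials)
           for d in range(dim)]
    return out, lse


scores = [0.3, -1.2, 2.0, 0.7, -0.5, 1.1]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 3.0], [4.0, 2.0]]
partials = [partial_attention(scores[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]
reference, ref_lse = partial_attention(scores, values)
streamed, streamed_lse = merge_streaming(partials)
folded, folded_lse = merge_post_loop(partials)
```

The post-loop variant has to keep every partial output alive until the loop ends, which is where the extra memory comes from.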
Ulysses [1] (a2a CP)
There are two types of layers in an LLM: sequence-level (e.g., attention) and token-level (e.g., MoE, MLP). Token-level layers do not need the entire sequence, because no interaction happens between positions within a sequence. Thus, we can split along the sequence dimension and let each CP rank handle a different part of the activation.
In Megatron, attention CP is handled by TE internally:
It forms a 3-stage software defined pipeline:
- local reshape
- output buffer creation + async communication trigger
- wait on communication and post-transformer
The shape manipulation during this process is shown as follows:
Before attention CP pipeline
By default, the Megatron CP scheduler organizes tokens in zig-zag order for load balance, but in a2a CP each rank now owns the full sequence, so TE has to permute again to recover the causal order.
After the attention, each CP rank holds a full-sequence, head-sharded tensor, and TE needs to convert it back to the sequence-sharded format:
CP software pipeline code in TE
This time, it's another 3-stage software pipeline conducting the steps:
- reorder to the load-balanced token order
- async collective
- post-processing, including reshape
After attention CP pipeline
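The two shape flips can be simulated without any distributed runtime. The following sketch is an illustration only: it models each rank's tensor as a nested list indexed `[seq][head]`, uses contiguous sequence shards (so it skips the zig-zag reordering step), and the two helper functions are hypothetical stand-ins for the all-to-all collectives.

```python
def seq_to_head_shard(shards):
    """Simulated all-to-all: shards[p][s_local][h] -> out[r][s_global][h_local]."""
    cp = len(shards)
    heads_per_rank = len(shards[0][0]) // cp
    out = []
    for r in range(cp):
        lo, hi = r * heads_per_rank, (r + 1) * heads_per_rank
        # Rank r receives its head slice of every token from every peer p.
        out.append([tok[lo:hi] for p in range(cp) for tok in shards[p]])
    return out


def head_to_seq_shard(shards):
    """Inverse all-to-all: shards[p][s_global][h_local] -> out[r][s_local][h]."""
    cp = len(shards)
    seq_per_rank = len(shards[0]) // cp
    out = []
    for r in range(cp):
        rows = []
        for s in range(r * seq_per_rank, (r + 1) * seq_per_rank):
            # Concatenate the head slices held by every peer for token s.
            rows.append([x for p in range(cp) for x in shards[p][s]])
        out.append(rows)
    return out


cp, seq, heads = 2, 8, 4
full = [[(s, h) for h in range(heads)] for s in range(seq)]
seq_shards = [full[r * seq // cp:(r + 1) * seq // cp] for r in range(cp)]

head_shards = seq_to_head_shard(seq_shards)   # per rank: [seq, heads / cp]
restored = head_to_seq_shard(head_shards)     # per rank: [seq / cp, heads]
print(len(head_shards[0]), len(head_shards[0][0]))
```

Each rank goes from `seq / cp` tokens with all heads to all tokens with `heads / cp` heads, and back again after attention.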
Hybrid (hierarchical) CP [2]
In MBridge, you can enable the a2a+p2p CP via the following config:
```python
cfg.model.context_parallel_size = 8
cfg.model.cp_comm_type = "a2a+p2p"
cfg.model.hierarchical_context_parallel_sizes = [4, 2]  # prod == 8
cfg.dist.use_decentralized_pg = False  # <- required for Bridge
```
Intrinsically, the CP size is bounded by the number of KV heads due to the a2a operation in the attention layer.
To exceed this limit, HCP combines intra-node a2a with inter-node p2p so that it can provide a larger CP group size.
a2a happens among the inner-CP ranks, which have faster GPU interconnects, and p2p runs across inter-node ranks connected via InfiniBand (IB).
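As a quick sanity check on these constraints, here is a small hypothetical helper (it is not part of Megatron or MBridge). It assumes the size list is ordered as `[inner a2a size, outer p2p size]`, as in the `[4, 2]` config above, and it additionally assumes that the inner a2a size has to divide the number of KV heads evenly.

```python
from math import prod


def check_hierarchical_cp(cp_size, hierarchical_sizes, num_kv_heads):
    """Validate an a2a+p2p split; raises ValueError on an inconsistent config."""
    icp, ocp = hierarchical_sizes  # assumed order: [inner a2a, outer p2p]
    if prod(hierarchical_sizes) != cp_size:
        raise ValueError(f"prod({hierarchical_sizes}) != cp_size ({cp_size})")
    if num_kv_heads % icp != 0:
        # Assumption: a2a shards heads, so icp must divide the KV head count.
        raise ValueError(f"{num_kv_heads} KV heads cannot be split over icp={icp}")
    return {"a2a_size": icp, "p2p_size": ocp}


ok = check_hierarchical_cp(cp_size=8, hierarchical_sizes=[4, 2], num_kv_heads=8)
print(ok)

try:
    # Pure a2a with CP=8 would not fit a model with only 4 KV heads.
    check_hierarchical_cp(cp_size=8, hierarchical_sizes=[8, 1], num_kv_heads=4)
    too_many = False
except ValueError:
    too_many = True
```

With 4 KV heads, a pure a2a group of 8 is rejected, while `[4, 2]` still reaches a total CP size of 8.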
Hybrid CP with inner CP size = 4 and outer CP size = 2
We use icp and ocp to denote the inner and outer CP sizes, so the total CP size equals icp × ocp.
The entire sequence is still split into CP size chunks, where each a2a domain owns icp consecutive tokens in the sequence dimension.
Inside TE, each ICP rank conducts a2a to flip the activation from context-split to head-split status, so that each ICP rank now holds the full sequence of its a2a domain but only 1/icp of the heads.
Then, before attention, it has to restore the causal order from the load-balancing order using the index mapping `chunk_ids_for_a2a` obtained from `get_seq_chunk_ids_for_reordering_before_attn`.
After the inner a2a, each rank holds $S/\mathrm{ocp}$ tokens and $H/\mathrm{icp}$ heads of $Q$, $K$ and $V$ as the attention inputs, where $S$ is the sequence length and $H$ the number of heads.
Then TE specifies the send and receive peers for the p2p exchanges of $K$ and $V$, with ocp ring steps in total.
Similar to the pure p2p CP described above, each rank retains its own $Q$ split and exchanges $K$/$V$ for the attention calculation.
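The ring exchange can be written down as a small schedule. This is a sketch under the usual ring convention (send to the next rank, receive from the previous one); the function and field names are illustrative and do not come from TE.

```python
def ring_schedule(ocp):
    """For every ring step and rank, list the peers and whose K/V is in use."""
    schedule = []
    for step in range(ocp):
        row = []
        for rank in range(ocp):
            row.append({
                "rank": rank,
                "send_to": (rank + 1) % ocp,       # next neighbor in the ring
                "recv_from": (rank - 1) % ocp,     # previous neighbor in the ring
                "kv_owner": (rank - step) % ocp,   # original owner of the K/V used now
            })
        schedule.append(row)
    return schedule


schedule = ring_schedule(4)
owners_seen_by_rank0 = [row[0]["kv_owner"] for row in schedule]
print(owners_seen_by_rank0)
```

After `ocp` steps every rank has combined its local queries with the K/V of every rank in the outer group exactly once.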
Software pipelining code for p2p CP attention in TE
In HCP, the inner CP group undoes the load balancing and permutes tokens into causal order, as shown in Fig. HCP.
As for the outer CP group, the order is still kept as zig-zag for load balance.
For example, with ocp set to 2, the sequence is cut into four chunks: outer CP rank 0 owns the first and the last chunk and the other outer CP rank owns the two middle chunks, while tokens within each chunk are in causal order.
allgather CP
There is another CP mode, allgather, which is probably the simplest among these 4 modes.
On each CP rank, $K$ and $V$ are all-gathered to produce the full-sequence, full-head attention input.
Besides, the token order is also un-zig-zagged so that $K$ and $V$ are in causal order.
AG CP shape transform
The attention query is still sharded across CP ranks in zig-zag order. Therefore, the computation stays balanced across CP ranks.
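The un-zig-zag step boils down to a fixed chunk permutation. The sketch below assumes that the all-gather concatenates each rank's two zig-zag chunks in rank order; the helper names are hypothetical and the string labels stand in for real K/V chunks.

```python
def gathered_chunk_order(cp_size):
    """Chunk ids as they appear in the all-gathered buffer (rank 0 first)."""
    order = []
    for rank in range(cp_size):
        order += [rank, 2 * cp_size - 1 - rank]
    return order


def unzigzag_permutation(cp_size):
    """perm[i] = position in the gathered buffer of the i-th chunk in causal order."""
    order = gathered_chunk_order(cp_size)
    return [order.index(chunk) for chunk in range(2 * cp_size)]


cp_size = 2
gathered = [f"c{chunk}" for chunk in gathered_chunk_order(cp_size)]
perm = unzigzag_permutation(cp_size)
causal = [gathered[i] for i in perm]
print(gathered, "->", causal)
```

For two ranks the gathered buffer holds chunks 0, 3, 1, 2, and the permutation puts them back into 0, 1, 2, 3.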
Sequence packing and balance
The traditional sbhd data layout requires all sequences within a batch to be padded to a fixed length, which wastes FLOPs.
In Megatron, one can convert sbhd into thd layout where t is the total number of tokens within a batch, which is expressed as follows:
`cu_seqlens_q(kv)_padded` is provided for the CUDA graph static input spec.
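As a reminder of what these arrays contain, here is a minimal sketch. `cu_seqlens` is the prefix sum of the sequence lengths, and the padded variant is the prefix sum after rounding every sequence up to some multiple. The helper names and the multiple of 6 (think 2 × CP size for CP = 3, so that each sequence splits evenly into zig-zag chunks) are assumptions made for the example.

```python
from itertools import accumulate


def cu_seqlens(seq_lens):
    """Cumulative sequence lengths: token offsets of each sequence in the t dim."""
    return [0] + list(accumulate(seq_lens))


def cu_seqlens_padded(seq_lens, multiple):
    """Same offsets after padding every sequence up to a multiple of `multiple`."""
    padded = [-(-n // multiple) * multiple for n in seq_lens]  # ceil to multiple
    return [0] + list(accumulate(padded))


seq_lens = [10, 23, 5]
offsets = cu_seqlens(seq_lens)
padded_offsets = cu_seqlens_padded(seq_lens, multiple=6)
print(offsets)
print(padded_offsets)
```

Sequence `i` occupies tokens `offsets[i]` to `offsets[i + 1]` in the packed `t` dimension, so no batch dimension is needed.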
Combined with CP, the thd tokens are scattered across CP ranks.
TE CUDA kernel
Here is a concrete example (given CP size = 3 and cu_seq_lens = [0, 12, 36, 42]); the kernel outputs zig-zagged indices for each CP rank:
Rank 0 indices: [ 0, 1, 10, 11, 12, 13, 14, 15, 32, 33, 34, 35, 36, 41]
Rank 1 indices: [ 2, 3, 8, 9, 16, 17, 18, 19, 28, 29, 30, 31, 37, 40]
Rank 2 indices: [ 4, 5, 6, 7, 20, 21, 22, 23, 24, 25, 26, 27, 38, 39]
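These index lists can be reproduced in a few lines of Python. The rank indices above are consistent with three packed sequences of lengths 12, 24 and 6 (cumulative lengths `[0, 12, 36, 42]`), each cut into `2 * cp_size` chunks. The function below is a reference sketch of that mapping, not the TE CUDA kernel.

```python
def thd_zigzag_indices(cu_seqlens, cp_size, rank):
    """Token indices owned by `rank` when each packed sequence is zig-zag split."""
    indices = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        seqlen = end - start
        assert seqlen % (2 * cp_size) == 0, "sequence must split into 2*cp chunks"
        chunk = seqlen // (2 * cp_size)
        front = start + rank * chunk                     # chunk `rank`
        back = start + (2 * cp_size - 1 - rank) * chunk  # mirrored chunk
        indices.extend(range(front, front + chunk))
        indices.extend(range(back, back + chunk))
    return indices


cu = [0, 12, 36, 42]
per_rank = [thd_zigzag_indices(cu, cp_size=3, rank=r) for r in range(3)]
for r, idx in enumerate(per_rank):
    print(f"Rank {r} indices: {idx}")
```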
Positional embeddings
When sequence packing is enabled, TE must take care of the positional embedding (PE), as the token order has been changed:
We can see under thd format, o_stride_s_or_t is set as and the batch dimension is ignored.
Also, the CUDA kernel maps token positions into freqs with s_id_for_freqs considering both the sequence packing and CP split:
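The position mapping follows directly from the zig-zag split. The sketch below is my own reference version in Python, not the kernel source: given a token's position inside a rank's local slice of one sequence, it returns the position in the original sequence, which is the index that should be used to look up the rotary frequencies.

```python
def position_for_freqs(local_pos, local_len, cp_size, cp_rank):
    """Map a local position in a rank's zig-zag slice to the original position."""
    half = local_len // 2  # each rank holds two chunks of `half` tokens
    if local_pos < half:
        # First half comes from chunk `cp_rank`.
        return cp_rank * half + local_pos
    # Second half comes from the mirrored chunk `2 * cp_size - 1 - cp_rank`.
    return (2 * cp_size - 1 - cp_rank) * half + (local_pos - half)


# A sequence of 24 tokens split over CP size 3: every rank holds 8 of them.
positions_rank1 = [position_for_freqs(p, 8, cp_size=3, cp_rank=1) for p in range(8)]
print(positions_rank1)
```

For a packed batch, the same mapping is applied per sequence, so the position restarts from 0 at each sequence boundary instead of running across the whole `t` dimension.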
DCP: Dynamic DPCP groups
Even with CP combined with sequence packing, there are still some bottlenecks that hinder performance [3]:
- static CP is pinned to the worst-case sequence in the batch.
- equal pack lengths ≠ equal compute (DP imbalance)
- CP communication stops hiding behind compute when packs are short
To overcome these, dynamic CP (DCP) is proposed with the idea that, for each micro batch, the scheduler dynamically chooses the best combination of CP size and packed-sequence composition. With multiple CP groups of varied CP sizes (shared with the DP group), Megatron can scale the CP domain to accommodate the sequence length distribution.
Currently (2026/05), DCP can only be enabled via MCore. You can turn it on for training with:
```shell
torchrun ... pretrain_gpt.py \
  --tensor-model-parallel-size $TP \
  --context-parallel-size $CP \
  --hybrid-context-parallel \
  --max-seqlen-per-dp-cp-rank $MAX_SEQ_PER_RANK \
  ...
```
Note that the `--hybrid-context-parallel` option here is a fundamentally different feature from the aforementioned HCP.
See the "pitfalls" section in MBridge docs.
The key argument is `--max-seqlen-per-dp-cp-rank`, which controls the maximum sequence length each DCP rank can receive.
DCP group
When DCP is enabled, Megatron creates several DCP groups whose sizes are powers of 2 up to $N$, where $N$ is the total number of ranks in the DP and CP dimensions.
Switching a rank between CP and DP doesn't require resharding model weights, so we can think of such an exchange as "free".
One thing to mention is per-token gradient scaling across the DP-CP group: the final gradients are divided by the total number of tokens in this micro batch across all DP-CP ranks.
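Why per-token scaling matters is easiest to see on the loss, since the gradient is linear in it. The toy sketch below uses made-up numbers and hypothetical helper names. It compares dividing the summed loss by the global token count against averaging per-rank means, which would over-weight ranks that happen to hold few tokens.

```python
def per_token_scaled(loss_sums, token_counts):
    """Sum of per-rank loss sums divided by the global token count."""
    return sum(loss_sums) / sum(token_counts)


def mean_of_rank_means(loss_sums, token_counts):
    """Naive alternative: average each rank's mean loss with equal weight."""
    means = [s / n for s, n in zip(loss_sums, token_counts)]
    return sum(means) / len(means)


# Rank 0 holds 8 tokens (loss 1.0 each); rank 1 holds a single token with loss 2.0.
loss_sums = [8.0, 2.0]
token_counts = [8, 1]
scaled = per_token_scaled(loss_sums, token_counts)
naive = mean_of_rank_means(loss_sums, token_counts)
print(scaled, naive)
```

With dynamic group sizes the token counts per rank vary a lot between micro batches, so the global-token normalization keeps every token's contribution the same regardless of where it was scheduled.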
Runtime CP size decision: BalancedCPScheduler
Given a batch of samples, max_seq_len_per_rank and the total number of DCP ranks, Megatron provides a scheduler named BalancedCPScheduler to determine:
- how DCP ranks are split into DP and CP groups
- how the sequences of the batch, which have various lengths, are fed into the different DP ranks
Concrete batch (8 ranks, max_seqlen_per_dp_cp_rank = 4K, workload = cp_size, represented by length in following diagram):
| length | CP size | stage | assigned ranks |
|---|---|---|---|
| 32K | 8 | 0 | 0..7 |
| 16K | 4 | 1 | 0..3 |
| 8K | 2 | 1 | 4..5 |
| 4K | 1 | 1 | 6 |
| 4K | 1 | 1 | 6 |
| 4K | 1 | 1 | 7 |
| 4K | 1 | 1 | 7 |
| 2K | 1 | 1 | 6 |
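The CP size column can be derived from the lengths alone. The sketch below is my reading of the table rather than the BalancedCPScheduler source: take the number of ranks needed so that no rank exceeds the per-rank limit, then round up to a power of 2 (an assumption, motivated by the DCP group sizes being powers of 2).

```python
K = 1024


def cp_size_for(length, max_seqlen_per_rank):
    """Smallest power-of-2 CP size that keeps length / cp_size within the limit."""
    needed = -(-length // max_seqlen_per_rank)  # ceil division
    size = 1
    while size < needed:
        size *= 2
    return size


lengths = [32 * K, 16 * K, 8 * K, 4 * K, 4 * K, 4 * K, 4 * K, 2 * K]
cp_sizes = [cp_size_for(n, max_seqlen_per_rank=4 * K) for n in lengths]
print(cp_sizes)
```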
Sequence scheduling goal: pack one global batch's sub-sequences onto the GPU rectangle, balanced and idle-free, with each sub-sequence being a rectangle: taller for longer sequences (more GPUs needed), wider for more per-GPU work.
- Bucket samples by size class. Group sub-sequences by how many GPUs they need, with roughly equal total work per bucket. Buckets get processed largest-first.
- Greedy fill. Walk the buckets and place each sub-sequence either into an existing group of the right size or onto fresh free GPUs, whichever leaves the worst-loaded GPU lighter.
- Stop when balanced. Once column heights are close enough, close the round.
- Trim overshoot. If one column ended up tall, peel its last-placed sub-sequence back into the leftover queue if that helps even things out.
- Fill empty GPUs. If any GPUs are still idle, keep doubling the smallest existing group's reach until every GPU has work, sliding neighbors aside as needed.
- Repeat per round. Whatever didn't fit goes into the next scheduling round, separated by a barrier so groups can safely reshape between rounds.
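To give a feel for the greedy-fill step, here is a heavily simplified single-round sketch. It is not Megatron's BalancedCPScheduler: it has no buckets, trimming, empty-GPU filling or rounds, it only considers aligned blocks of `cp_size` GPUs, and it assumes per-GPU work is measured as length / cp_size (in K tokens). All names are made up for the example.

```python
def greedy_fill(samples, num_gpus):
    """samples: list of (cp_size, work_per_gpu). Returns (loads, placement)."""
    loads = [0] * num_gpus
    placement = []
    # Largest CP size first, mirroring "buckets get processed largest-first".
    for cp_size, work in sorted(samples, key=lambda s: -s[0]):
        starts = range(0, num_gpus, cp_size)  # aligned candidate GPU blocks
        # Pick the block whose most-loaded GPU is currently the lightest.
        best = min(starts, key=lambda st: max(loads[st:st + cp_size]))
        # All GPUs of a CP group finish together, so they share one height.
        top = max(loads[best:best + cp_size]) + work
        for gpu in range(best, best + cp_size):
            loads[gpu] = top
        placement.append((cp_size, work, best))
    return loads, placement


# (cp_size, length / cp_size in K) for the batch in the table above.
samples = [(8, 4), (4, 4), (2, 4), (1, 4), (1, 4), (1, 4), (1, 4), (1, 2)]
loads, placement = greedy_fill(samples, num_gpus=8)
print(loads)
```

Because this sketch stacks everything into one round, it ends up less balanced than a multi-round plan; the real scheduler closes a round once the columns are close enough and defers the rest.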
Concepts to distinguish:
- global batch: per `DataLoader.__next__` / per optimizer `step`
- micro batch: per `forward`/`backward`
- schedule round: the number of sync barriers in the DCP ranks
Note that Megatron does NOT provide an optimal scheduling plan for DCP.
Sequence packing with DCP
If you specify the thd format together with DCP, Megatron will split the packed sequences back into sbhd format, since two consecutive sequences on a rank may have different CP sizes and the thd packing format would lose that flexibility.
So there will be exactly $n$ forward/backward passes on each DCP rank, where $n$ is the number of sequences in the current group (scheduling round).
Outro
So when should you use each CP mode? I think the MBridge docs already have the answer:
- https://docs.nvidia.com/nemo/megatron-bridge/latest/training/hybrid-context-parallel.html#when-to-use-it
- https://docs.nvidia.com/nemo/megatron-bridge/nightly/training/hierarchical-context-parallel.html#when-to-use-it
- https://docs.nvidia.com/nemo/megatron-bridge/latest/performance-guide.html#long-sequence-training
- https://docs.nvidia.com/nemo/megatron-bridge/latest/performance-guide.html#sequence-packing-for-performant-fine-tuning
This post also answers the question: why is sbhd the popular batch format among recipes?
A: AlltoAll and split operations ship contiguous tensors if the first dimension is the sequence dimension.