Sequence Parallel
This section describes how to add Sequence Parallel (SP) to a diffusion transformer model. We use the Qwen-Image transformer and Wan2.2 transformer as reference implementations.
Table of Contents
- Overview
- UAA Mode (Experimental)
- Approach 1: Non-Intrusive `_sp_plan` (Recommended)
- Approach 2: Intrusive Modification (For Complex Cases)
- Testing
- Troubleshooting
- Reference Implementations
- Summary
Overview
What is Sequence Parallel?
Terminology Note: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in the diffusers library. We use "Sequence Parallelism" to align with vLLM-Omni's terminology.
Diffusion transformers process long sequences of image patches or video frames. For high-resolution generation, these sequences can become very large. Enabling SP allows each GPU to process only a portion of the sequence, with attention mechanisms (Ulysses/Ring) handling cross-GPU communication transparently.
Architecture
The main Sequence Parallel APIs are:
```python
from vllm_omni.diffusion.distributed.sp_plan import (
    SequenceParallelInput,   # For sharding (splitting) tensors
    SequenceParallelOutput,  # For gathering tensors
)
from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather
```
| Method/Class | Purpose | Behavior |
|---|---|---|
| `SequenceParallelInput` | Declare input sharding in `_sp_plan` | Auto-shards tensors at module input (or output, with `split_output=True`) |
| `SequenceParallelOutput` | Declare output gathering in `_sp_plan` | Auto-gathers tensors at module output |
| `sp_shard()` | Manual tensor sharding | Splits a tensor across SP workers |
| `sp_gather()` | Manual tensor gathering | Gathers sharded tensors from all workers |
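To make the shard/gather semantics concrete, here is a pure-Python sketch that mimics the table on nested lists shaped `[batch, seq, hidden]`. It uses no torch and no distributed backend. The helper names `toy_sp_shard` and `toy_sp_gather` are hypothetical. The sketch assumes contiguous, equal-sized chunks along the sequence dimension (as in strict Ulysses mode) and simulates every rank in one process. The real `sp_shard()`/`sp_gather()` operate on tensors and use collective communication.

```python
from typing import List

# Toy stand-in for a [batch, seq, hidden] tensor.
Tensor3D = List[List[List[float]]]


def toy_sp_shard(x: Tensor3D, rank: int, world_size: int) -> Tensor3D:
    """Return this rank's contiguous chunk of the sequence dimension (dim=1)."""
    seq_len = len(x[0])
    if seq_len % world_size != 0:
        # Strict behavior: uneven shards are rejected rather than padded.
        raise ValueError(f"seq_len={seq_len} is not divisible by world_size={world_size}")
    chunk = seq_len // world_size
    return [sample[rank * chunk:(rank + 1) * chunk] for sample in x]


def toy_sp_gather(shards: List[Tensor3D]) -> Tensor3D:
    """Concatenate per-rank shards along dim=1, in rank order."""
    batch = len(shards[0])
    return [[tok for shard in shards for tok in shard[b]] for b in range(batch)]


x = [[[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]]  # batch=1, seq=4, hidden=2
shards = [toy_sp_shard(x, rank, world_size=2) for rank in range(2)]
print(shards[1])  # [[[2.0, 2.1], [3.0, 3.1]]]
assert toy_sp_gather(shards) == x  # gathering undoes sharding
```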
UAA Mode (Experimental)
`ulysses_mode="advanced_uaa"` enables the experimental UAA ("Ulysses Anything Attention") feature, which lets Ulysses attention handle arbitrary sequence lengths and arbitrary attention head counts. Cache-DiT also supports the same idea.
Use it when plain Ulysses-SP would otherwise fail because:
- the local sequence shards are not evenly divisible after the split hooks run, or
- the attention head count is not divisible by `ulysses_degree`.
Design Summary
- Strict mode stays unchanged. `ulysses_mode="strict"` keeps the original fast path and still requires divisible sequence/head shapes.
- UAA uses variable all-to-all split sizes for sequence shards. Before the Ulysses Q/K/V exchange, each rank all-gathers its local sequence length and passes those lengths as `all_to_all_single(..., output_split_sizes=seq_lens)`. This lets Ulysses gather the full sequence even when ranks start with different local shard lengths.
- UAA pads heads only inside the Ulysses exchange. If `head_cnt % ulysses_degree != 0`, UAA pads the head dimension up to the next multiple of `ulysses_degree`, performs the forward/reverse all-to-all, then slices the temporary head padding away after the reverse exchange. The same rule applies to joint attention tensors.
- Hybrid Ulysses + Ring is still shape-constrained. Ring attention expects every rank in a ring group to exchange exactly the same post-Ulysses sequence shape. UAA therefore validates those shapes before entering the ring path and raises a clear error if ring peers disagree on `S_global`.
- Tiny scalar gathers stay out of TorchDynamo tracing. `_all_gather_int()` is marked with `@torch.compiler.disable` so that the scalar `item()` conversions used for UAA metadata collection are not traced into `torch.compile`.
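The shape bookkeeping behind the two UAA rules above (head padding and variable split sizes) can be sketched with plain integers. This is not the vLLM-Omni implementation. The function names are hypothetical, and the sketch only computes shapes; it performs no all-to-all. It assumes heads are padded up to the next multiple of `ulysses_degree` and that the all-gathered per-rank sequence lengths are used directly as the split sizes.

```python
from typing import Dict, List


def padded_head_count(head_cnt: int, ulysses_degree: int) -> int:
    """Round head_cnt up to the next multiple of ulysses_degree."""
    return -(-head_cnt // ulysses_degree) * ulysses_degree


def uaa_shapes(head_cnt: int, ulysses_degree: int, seq_lens: List[int]) -> Dict[str, object]:
    """Shape bookkeeping for one Ulysses exchange under UAA-style rules.

    seq_lens holds every rank's local sequence length, i.e. the result of
    all-gathering the local lengths before the Q/K/V all-to-all.
    """
    assert len(seq_lens) == ulysses_degree
    padded = padded_head_count(head_cnt, ulysses_degree)
    return {
        "padded_heads": padded,
        "head_padding": padded - head_cnt,           # sliced off after the reverse exchange
        "heads_per_rank": padded // ulysses_degree,  # heads each rank attends over
        "output_split_sizes": list(seq_lens),        # one (possibly different) size per source rank
        "full_seq_len": sum(seq_lens),               # sequence each rank sees after the exchange
    }


# 30 heads do not divide evenly across 4 ranks, and the shards are uneven.
shapes = uaa_shapes(head_cnt=30, ulysses_degree=4, seq_lens=[258, 258, 257, 257])
print(shapes["padded_heads"], shapes["heads_per_rank"], shapes["full_seq_len"])  # 32 8 1030
```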
UAA vs auto_pad
- `auto_pad=True` pads sequence tokens in `_sp_plan` and requires attention backends that support `attention_mask`.
- `advanced_uaa` does not depend on mask-based token padding inside Ulysses attention. It is therefore a better fit for non-divisible head counts and uneven Ulysses shard sizes.
- `auto_pad=True` remains incompatible with Ring attention because the ring backend does not consume `attention_mask`.
- `advanced_uaa` is still experimental, and hybrid mode remains limited by Ring's equal-shape requirement.
Approach 1: Non-Intrusive `_sp_plan` (Recommended)
The _sp_plan mechanism allows SP without modifying forward() logic. The framework automatically registers hooks to shard inputs and gather outputs at module boundaries.
When to use:
- Standard transformer architectures
- Tensor operations happen at `nn.Module` boundaries
- Predictable sharding/gathering patterns

This is the preferred approach for adding sequence parallelism to new models, because it is easier to maintain and easier to keep compatible with other acceleration features.

How it works:
1. Declare an `_sp_plan` dict in your transformer class.
2. The framework automatically applies hooks when `sequence_parallel_size > 1`.
3. The hooks shard/gather tensors at the specified module boundaries.
4. Attention layers handle cross-GPU communication internally.
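```python
import torch.nn as nn

from vllm_omni.diffusion.distributed.sp_plan import (
    SequenceParallelInput,
    SequenceParallelOutput,
)


class StandardTransformer(nn.Module):
    _sp_plan = {
        # Shard hidden_states at the input of the first transformer block
        "blocks.0": {
            "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
        },
        # Gather at the output of the final projection
        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
    }
```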
`StandardTransformer` has a list of transformer blocks, `self.blocks = nn.ModuleList([...])`, and an output projection layer, `self.proj_out`. When SP is enabled, the `_sp_plan` above shards the `hidden_states` input of the first transformer block and gathers the sharded tensor at the output of the final projection layer.
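The effect of this plan can be simulated in a single process. The sketch below is hypothetical and torch-free. `block` and `proj_out` are token-wise stand-ins, so they need no cross-rank communication; real blocks also run attention, which exchanges data through Ulysses/Ring. The sketch shards the sequence before the first block, runs every block and the output projection on each rank's shard, gathers after `proj_out`, and checks the result against the unsharded forward pass.

```python
from typing import List

Seq = List[float]


def block(tokens: Seq) -> Seq:
    # Token-wise stand-in for a transformer block (real blocks also run attention).
    return [2.0 * t + 1.0 for t in tokens]


def proj_out(tokens: Seq) -> Seq:
    # Token-wise stand-in for the output projection.
    return [t - 0.5 for t in tokens]


def forward(tokens: Seq, num_blocks: int) -> Seq:
    for _ in range(num_blocks):
        tokens = block(tokens)
    return proj_out(tokens)


def forward_with_toy_sp_plan(tokens: Seq, num_blocks: int, world_size: int) -> Seq:
    chunk = len(tokens) // world_size  # assumes divisibility (strict mode)
    outputs = []
    for rank in range(world_size):
        # "blocks.0" input hook: keep only this rank's slice of hidden_states.
        local = tokens[rank * chunk:(rank + 1) * chunk]
        # All blocks and proj_out run on the local shard.
        outputs.append(forward(local, num_blocks))
    # "proj_out" output hook: gather shards along the sequence dim in rank order.
    return [t for out in outputs for t in out]


hidden_states = [float(i) for i in range(8)]
assert forward_with_toy_sp_plan(hidden_states, num_blocks=3, world_size=4) == forward(hidden_states, 3)
```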
Requirements:
- Tensor operations that need sharding/gathering must happen at `nn.Module` boundaries.
- Inline Python operations (e.g., `torch.cat`, `pad_sequence`) cannot be hooked.
Solution for inline operations: Extract into a submodule (see Step 2 below).
Step 1: Understand Module Boundaries
Identify where tensors need to be sharded or gathered in your model's forward pass:
```python
import torch.nn as nn


class MyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        # Schematic submodules; constructor arguments omitted for brevity.
        self.patch_embed = PatchEmbed()     # ← Boundary 1
        self.pos_embed = RoPE()             # ← Boundary 2
        self.blocks = nn.ModuleList([...])  # ← Boundary 3
        self.norm_out = LayerNorm()
        self.proj_out = Linear()            # ← Boundary 4

    def forward(self, x):
        x = self.patch_embed(x)    # ← Shard before this?
        pos = self.pos_embed(x)    # ← Shard RoPE outputs?
        for block in self.blocks:
            x = block(x, pos)      # ← Blocks process sharded x
        x = self.norm_out(x)
        output = self.proj_out(x)  # ← Gather after this?
        return output
```
Step 2: Handle Inline Operations
If your forward() contains inline tensor operations, extract them into submodules.
Example: Z-Image concatenates image + text features inline
```python
import torch
import torch.nn as nn


# ❌ BAD: Inline operation - hooks cannot intercept it
class ZImageTransformer(nn.Module):
    def forward(self, x, cap_feats):
        # This concatenation happens inline, so _sp_plan can't shard it!
        unified = torch.cat([x, cap_feats], dim=1)
        for layer in self.layers:
            unified = layer(unified)
        return unified


# ✅ GOOD: Extract into a submodule
class UnifiedPrepare(nn.Module):
    """Submodule that concatenates image and text features."""

    def forward(self, x, cap_feats):
        return torch.cat([x, cap_feats], dim=1)


class ZImageTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.unified_prepare = UnifiedPrepare()  # Now a module!
        self.layers = nn.ModuleList([...])

    def forward(self, x, cap_feats):
        # Now _sp_plan can shard the output of unified_prepare!
        unified = self.unified_prepare(x, cap_feats)
        for layer in self.layers:
            unified = layer(unified)
        return unified
```
Other common cases:
- `pad_sequence()` → `PadSequenceModule`
- `torch.cat()` → `ConcatModule`
- `tensor.reshape()` → `ReshapeModule`
- Complex preprocessing → `PreprocessModule`

Step 3: Write `_sp_plan` for Your Model
Create a class-level _sp_plan dictionary specifying where to shard/gather tensors.
Diffusion models typically follow one of three patterns:
Pattern 1: Shard at first block, gather at output projection
Most common pattern for standard transformers:
```python
import torch.nn as nn

from vllm_omni.diffusion.distributed.sp_plan import (
    SequenceParallelInput,   # For sharding (splitting) tensors
    SequenceParallelOutput,  # For gathering tensors
)


class StandardTransformer(nn.Module):
    _sp_plan = {
        # Shard hidden_states at the input of the first transformer block
        "blocks.0": {
            "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
        },
        # Gather at the output of the final projection
        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
    }
```
Pattern 2: Shard RoPE embeddings separately
When RoPE is computed in a separate module:
```python
import torch.nn as nn

from vllm_omni.diffusion.distributed.sp_plan import (
    SequenceParallelInput,   # For sharding (splitting) tensors
    SequenceParallelOutput,  # For gathering tensors
)


class TransformerWithRoPE(nn.Module):
    _sp_plan = {
        # Shard the rope module's OUTPUTS (it returns a (cos, sin) tuple)
        "rope": {
            0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # cos
            1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),  # sin
        },
        # Shard the transformer block INPUT
        "blocks.0": {
            "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
        },
        # Gather at the output
        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
    }
```
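RoPE must be sharded with the same split as `hidden_states` because a rank's local token *i* actually sits at global position `rank * chunk + i`. The hypothetical pure-Python sketch below illustrates this. It uses a toy single-frequency cos table in place of the real `(cos, sin)` tensors. Slicing the RoPE output with the same split keeps the positions aligned. An unsharded table does not even match the local length, and recomputing RoPE from the local length restarts positions at 0.

```python
import math
from typing import List


def rope_cos(seq_len: int, inv_freq: float = 0.1) -> List[float]:
    """Toy single-frequency cos table: one value per sequence position."""
    return [math.cos(pos * inv_freq) for pos in range(seq_len)]


def shard(seq: list, rank: int, world_size: int) -> list:
    """Contiguous split along the sequence (assumes divisibility)."""
    chunk = len(seq) // world_size
    return seq[rank * chunk:(rank + 1) * chunk]


seq_len, world_size, rank = 8, 2, 1
full_cos = rope_cos(seq_len)
local_tokens = shard(list(range(seq_len)), rank, world_size)  # global positions 4..7

# Correct: shard the RoPE output with the same split as hidden_states.
local_cos = shard(full_cos, rank, world_size)
assert len(local_cos) == len(local_tokens)
assert local_cos == [math.cos(pos * 0.1) for pos in local_tokens]

# Wrong: an unsharded table has 8 rows for 4 local tokens (shape mismatch).
assert len(full_cos) != len(local_tokens)

# Wrong: recomputing RoPE from the local length restarts positions at 0.
assert rope_cos(len(local_tokens)) != local_cos
```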
Pattern 3: Shard RoPE for Dual-Stream Attention
In some cases, different streams in attention need to handle sequence parallelism differently. For example, we may want to shard the image embeddings while replicating the text embeddings so that joint attention is configured correctly.
```python
import torch.nn as nn

from vllm_omni.diffusion.distributed.sp_plan import (
    SequenceParallelInput,
    SequenceParallelOutput,
)


class DualStreamTransformer(nn.Module):
    """
    Dual-stream model where we need to replicate the text components but shard
    the image components to correctly handle sequence parallelism.
    """

    _sp_plan = {
        # Here rope_preparer returns a tuple of length 4: the first 2 items
        # correspond to the text and the last 2 to the visual inputs, so we
        # only shard the last 2.
        "rope_preparer": {
            # Outputs 0, 1 (text) - NOT sharded (replicated)
            # Outputs 2, 3 (image) - sharded
            2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),  # img_cos
            3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),  # img_sin
        },
        # Shard the transformer block INPUT
        "transformer_blocks.0": {
            "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
        },
        # Gather at the output
        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
    }
```
NOTE: Test carefully when refactoring classes that use this style of plan. Changing the order of the return values will break sequence parallelism.
API Reference
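A hypothetical pure-Python sketch of what an integer-keyed entry like `rope_preparer` above does follows. Only the tuple positions listed in the plan are split, and every other output passes through unchanged (replicated on all ranks). A plain set of indices stands in for the real `SequenceParallelInput` entries. Lists stand in for tensors, so only a split along dim 0 is modeled. The last lines show why reordering the return values is dangerous.

```python
from typing import AbstractSet, Tuple


def shard_outputs(outputs: Tuple[list, ...], shard_indices: AbstractSet[int],
                  rank: int, world_size: int) -> Tuple[list, ...]:
    """Split only the outputs whose tuple index is in shard_indices (along dim 0)."""
    result = []
    for idx, value in enumerate(outputs):
        if idx in shard_indices:
            chunk = len(value) // world_size  # assumes divisibility
            result.append(value[rank * chunk:(rank + 1) * chunk])
        else:
            result.append(value)  # replicated on every rank (e.g. text RoPE)
    return tuple(result)


# (txt_cos, txt_sin, img_cos, img_sin): 2 text positions, 4 image positions.
outputs = (["tc0", "tc1"], ["ts0", "ts1"],
           ["ic0", "ic1", "ic2", "ic3"], ["is0", "is1", "is2", "is3"])
txt_cos, txt_sin, img_cos, img_sin = shard_outputs(outputs, {2, 3}, rank=1, world_size=2)
print(img_cos)  # ['ic2', 'ic3']
print(txt_cos)  # ['tc0', 'tc1'] (full text table on every rank)

# If a refactor swaps the return order, the same plan now splits text and replicates image.
swapped = (outputs[2], outputs[3], outputs[0], outputs[1])
bad = shard_outputs(swapped, {2, 3}, rank=1, world_size=2)
print(len(bad[0]), len(bad[2]))  # 4 1
```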
SequenceParallelInput Parameters:
| Parameter | Type | Description |
|---|---|---|
| `split_dim` | `int` | Dimension to split (usually 1 for sequence) |
| `expected_dims` | `int \| None` | Expected tensor rank for validation (optional) |
| `split_output` | `bool` | `False`: shard input params; `True`: shard output tensors |
| `auto_pad` | `bool` | Auto-pad if the sequence length is not divisible by `world_size` (default: `False`) |
SequenceParallelOutput Parameters:
| Parameter | Type | Description |
|---|---|---|
| `gather_dim` | `int` | Dimension to gather (usually 1 for sequence) |
| `expected_dims` | `int \| None` | Expected tensor rank for validation (optional) |
Module Naming Conventions:
| Key | Meaning | Python equivalent |
|---|---|---|
| `""` | Root model | `model` |
| `"blocks.0"` | First element of ModuleList | `model.blocks[0]` |
| `"blocks.*"` | All elements of ModuleList | `for b in model.blocks` |
| `"rope"` | Named submodule | `model.rope` |
| `"outputs.main"` | ModuleDict entry | `model.outputs["main"]` |
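The naming conventions in the table can be illustrated with a small, hypothetical resolver over plain Python objects. Attributes stand for named submodules, lists for `ModuleList`, and dicts for `ModuleDict`. This is not the framework's resolver, which walks `nn.Module` trees. It only shows how each key maps to the targets in the "Python equivalent" column.

```python
from types import SimpleNamespace
from typing import Any, List


def resolve(root: Any, key: str) -> List[Any]:
    """Return every object a plan key refers to ("" -> root, "*" -> all children)."""
    if key == "":
        return [root]
    targets = [root]
    for part in key.split("."):
        next_targets = []
        for obj in targets:
            if part == "*":
                next_targets.extend(obj.values() if isinstance(obj, dict) else obj)
            elif isinstance(obj, list):
                next_targets.append(obj[int(part)])  # ModuleList index
            elif isinstance(obj, dict):
                next_targets.append(obj[part])       # ModuleDict entry
            else:
                next_targets.append(getattr(obj, part))  # named submodule
        targets = next_targets
    return targets


model = SimpleNamespace(
    blocks=["block0", "block1", "block2"],  # stands in for nn.ModuleList
    rope="rope_module",                     # named submodule
    outputs={"main": "main_head"},          # stands in for nn.ModuleDict
)
print(resolve(model, "blocks.0"))      # ['block0']
print(resolve(model, "blocks.*"))      # ['block0', 'block1', 'block2']
print(resolve(model, "rope"))          # ['rope_module']
print(resolve(model, "outputs.main"))  # ['main_head']
```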
Dictionary Value Types:
| Key type | `split_output` | Description |
|---|---|---|
| `"param_name"` (str) | `False` | Shard input parameter by name |
| `0, 1, ...` (int) | `True` | Shard output tuple by index |
Approach 2: Intrusive Modification (For Complex Cases)
For models with dynamic sharding logic that cannot be expressed via _sp_plan, manually insert shard/gather calls.
When to use:
- Dynamic/conditional sharding logic
- Complex tensor manipulations that can't be encapsulated
- Temporary workaround during development

```python
from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather


def forward(self, hidden_states, *args, **kwargs):
    sp_enabled = self.parallel_config.sequence_parallel_size > 1
    if sp_enabled:
        # Split the sequence dimension so each SP rank keeps only its shard
        hidden_states = sp_shard(hidden_states, dim=1)

    # ... model-specific computation on the local shard ...
    output = hidden_states  # placeholder for the real computation

    if sp_enabled:
        # Reassemble the full sequence from all SP ranks
        output = sp_gather(output, dim=1)
    return output
```
Testing
After implementing Sequence Parallel support, thoroughly test your implementation to ensure correctness and performance across different configurations.
Test Different sp_size:
Test your model with various sequence parallel world sizes to verify correctness and identify optimal configurations:
```bash
cd examples/offline_inference/text_to_image
python text_to_image.py \
  --model Your-org/your-model \
  --prompt "a cup of coffee on the table" \
  --num-inference-steps 50 \
  --ulysses-degree 2 \
  --ring-degree 2 \
  --output sp_test_image_ulysses=2_ring=2.png
```
Verify:
- Correctness: Output should match across all `sp_size` values (up to small floating-point differences).
- Speed: Throughput should remain stable or improve (especially for large sequences).
- Logs: Check for any shape mismatch or communication errors.
Test with Tensor Parallel:
Sequence Parallel can be combined with other parallelism strategies:
```bash
cd examples/offline_inference/text_to_image
python text_to_image.py \
  --model Your-org/your-model \
  --prompt "a cup of coffee on the table" \
  --num-inference-steps 50 \
  --ulysses-degree 2 \
  --tensor-parallel-size 2 \
  --output sp_test_image_ulysses=2_tp=2.png
```
Troubleshooting
Issue: Shape mismatch errors
Symptoms: RuntimeError: shape mismatch during forward pass.
Causes & Solutions:
- RoPE dimension mismatch:
  Problem: RoPE embeddings are not sharded, but `hidden_states` is.
  Solution: Shard the RoPE outputs in `_sp_plan`:

```python
_sp_plan = {
    "rope": {
        0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),
        1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),
    },
    # ... other entries ...
}
```
- Sequence length not divisible by `sp_size`:
  Problem: Strict Ulysses sequence parallel requires divisible shapes. If the shard lengths are uneven, or if the model's head count is not divisible by `ulysses_degree`, the strict path raises an error.
  Solutions:
  - Use `ulysses_mode="advanced_uaa"` for Ulysses-SP when you want the experimental uneven-shape path without relying on attention-mask padding.
  - If the model already uses `_sp_plan` token padding and the attention backend supports `attention_mask`, set `auto_pad=True` and add attention-mask plumbing.
Experimental Feature: `ulysses_mode="advanced_uaa"` is experimental. It is intended to relax Ulysses divisibility constraints, but hybrid Ulysses + Ring still requires equal post-Ulysses sequence lengths inside each ring group.

Experimental Feature: `auto_pad=True` is experimental and may change in the future. We plan to improve this solution so that it requires minimal changes to model files. More details are here.
Constraints of auto_pad:
| Constraint | Description |
|---|---|
| Attention Backend Compatibility | The attention backends must support attention_mask. Currently only TORCH_SDPA and FLASH_ATTN (default for diffusion models) are supported. |
| Ring Attention Limitation | Ring attention does not support attention_mask. Therefore, when using auto_pad=True, the combination of Ulysses + Ring attention is not feasible. |
- Enable `auto_pad=True` for all sequence-dimension inputs in `_sp_plan`:

```python
_sp_plan = {
    "rope": {
        0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True, auto_pad=True),
        1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True, auto_pad=True),
    },
    "blocks.0": {
        "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True),
    },
    # ... other entries ...
}
```

- Create the attention mask dynamically when padding is applied:

```python
import torch

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.forward_context import get_forward_context

# In the model's forward(), before the transformer blocks:
hidden_states_mask = None
ctx = get_forward_context()
if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
    batch_size = hidden_states.shape[0]
    padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
    # True marks real tokens; the trailing padded positions are masked out
    hidden_states_mask = torch.ones(
        batch_size, padded_seq_len, dtype=torch.bool, device=hidden_states.device
    )
    hidden_states_mask[:, ctx.sp_original_seq_len:] = False

# Pass the mask to the attention layers
attn_metadata = AttentionMetadata(attn_mask=hidden_states_mask) if hidden_states_mask is not None else None
output = self.attn(query, key, value, attn_metadata)
```
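The padding arithmetic behind `ctx.sp_original_seq_len` and `ctx.sp_padding_size` can be sketched in pure Python. The sketch assumes the sequence is padded at the end up to the next multiple of the SP world size. This matches the mask above, which marks every position from `sp_original_seq_len` onward as padding. The helper names are hypothetical, and lists of booleans stand in for the `torch.bool` mask.

```python
from typing import List


def sp_padding_size(seq_len: int, world_size: int) -> int:
    """Tokens to append so that seq_len becomes divisible by world_size."""
    return (-seq_len) % world_size


def padded_mask(batch_size: int, seq_len: int, world_size: int) -> List[List[bool]]:
    """True for real tokens, False for trailing padding (mirrors hidden_states_mask)."""
    pad = sp_padding_size(seq_len, world_size)
    return [[True] * seq_len + [False] * pad for _ in range(batch_size)]


seq_len, world_size = 10, 4
pad = sp_padding_size(seq_len, world_size)
local_len = (seq_len + pad) // world_size  # equal shard length on every rank
mask = padded_mask(batch_size=1, seq_len=seq_len, world_size=world_size)
print(pad, local_len)        # 2 3
print(mask[0].count(False))  # 2
```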
Important Quality Considerations:
While auto_pad enables generation for irregular resolutions, be aware of potential quality impacts:
| Aspect | Impact |
|---|---|
| Training Distribution | Models perform best on aspect ratios seen during training (e.g., 1:1, 16:9, 4:3). Unusual ratios like 700x400 (1.75:1) may produce lower quality results. |
| Padding Overhead | Padded positions consume compute even when masked. For best efficiency, prefer resolutions divisible by sp_size. |
Recommendations for users:
- Use standard aspect ratios when possible (e.g., 768x432 for 16:9 instead of 700x400).
- Ensure post-patch dimensions are divisible by `sp_size` for optimal quality.
- Test generation quality when using unusual resolutions.
Issue: Inline operations not sharded
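To check a resolution before generating, you can estimate the post-patch sequence length and the padding `auto_pad` would add. This hypothetical sketch assumes a total spatial downsampling factor of 16 (for example, an 8x VAE followed by 2x2 patchify) and a single image. The factor is model-specific, so look it up for your model before relying on these numbers. The 704x400 example resolution is chosen here for illustration only.

```python
from typing import Tuple

# Hypothetical total spatial downsampling (e.g. 8x VAE * 2x2 patchify); model-specific.
DOWNSAMPLE = 16


def post_patch_seq_len(height: int, width: int, downsample: int = DOWNSAMPLE) -> int:
    """Image tokens after VAE encoding and patchify, under the assumed factor."""
    return (height // downsample) * (width // downsample)


def padding_report(height: int, width: int, sp_size: int,
                   downsample: int = DOWNSAMPLE) -> Tuple[int, int, float]:
    """Return (seq_len, padding tokens, fraction of padded compute)."""
    seq_len = post_patch_seq_len(height, width, downsample)
    pad = (-seq_len) % sp_size
    return seq_len, pad, pad / (seq_len + pad)


print(padding_report(768, 432, sp_size=8))  # (1296, 0, 0.0)
seq_len, pad, frac = padding_report(704, 400, sp_size=8)
print(seq_len, pad)  # 1100 4
```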
Symptoms: Some tensors remain full-sized, not sharded.
Causes & Solutions:
- Operations happen inline in `forward()`, not at module boundaries:
  Problem:

```python
def forward(self, x, cap):
    unified = torch.cat([x, cap], dim=1)  # ← Inline operation!
    # _sp_plan can't hook this
```

  Solution: Extract the operation into a submodule:

```python
import torch
import torch.nn as nn

from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput


class ConcatModule(nn.Module):
    def forward(self, x, cap):
        return torch.cat([x, cap], dim=1)


class MyModel(nn.Module):
    _sp_plan = {
        # ConcatModule returns a single [B, S, D] tensor; this assumes a
        # single-tensor output is addressed as output index 0.
        "concat": {
            0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True),
        },
        # ... other entries ...
    }

    def __init__(self):
        super().__init__()
        self.concat = ConcatModule()  # Now hookable!

    def forward(self, x, cap):
        unified = self.concat(x, cap)  # ← Can be sharded via _sp_plan
```
Reference Implementations
Complete examples in the codebase:
| Model | Path | Pattern | Notes |
|---|---|---|---|
| LongCat | vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py | Dual-stream | Text components replicated, image components sharded |
| Qwen-Image | vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py | Dual-stream + preprocessing | auto_pad, separate RoPE |
| Wan2.2 | vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py | Dual-Transformer + RoPE | Video transformer |
| Z-Image | vllm_omni/diffusion/models/z_image/z_image_transformer.py | Unified sequence | Concatenated input |
| SP Plan Types | vllm_omni/diffusion/distributed/sp_plan.py | Type definitions | SequenceParallelInput/Output |
| Hook Implementation | vllm_omni/diffusion/hooks/sequence_parallel.py | Hook mechanics | How hooks work |
| Tests | tests/diffusion/distributed/test_sp_plan_hooks.py | Test examples | Validation patterns |
Summary
Adding Sequence Parallel support to a transformer:
- ✅ Choose an approach: Use `_sp_plan` for standard cases and intrusive modification for complex cases.
- ✅ Identify sharding boundaries: Where should tensors be split/gathered, and which module boundaries need to be moved to make this possible?
- ✅ Extract inline operations: Move `torch.cat`, `pad_sequence`, etc. into submodules.
- ✅ Define `_sp_plan`: Declare shard/gather points as a class attribute.
- ✅ Handle variable lengths: Use `auto_pad` (or the experimental `ulysses_mode="advanced_uaa"`) to support non-uniform sequences.
- ✅ Test: Verify with different `ulysses_degree` and `ring_degree` combinations.