vllm_omni.diffusion.distributed.sp_plan

Sequence Parallelism configuration and plan type definitions.

This module provides:

1. SequenceParallelConfig: Configuration for SP (ulysses_degree, ring_degree)
2. SequenceParallelInput/Output: Type definitions for _sp_plan declarations
3. Validation utilities for _sp_plan

A _sp_plan is a dictionary that specifies how to shard/gather tensors at different points in a model's forward pass. This allows automatic handling of sequence parallelism without modifying the model's forward() method.

NOTE: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in diffusers. We use "Sequence Parallelism" to align with vLLM-Omni terminology.

Example

class MyTransformer(nn.Module):
    _sp_plan = {
        # Split inputs before model forward
        "": {
            "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
            "encoder_hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
        },
        # Split RoPE embeddings after pos_embed layer
        "pos_embed": {
            0: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),
        },
        # Gather output after proj_out layer
        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
    }
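To make the structure of such a plan concrete, the sketch below walks a plan dictionary and prints what each entry asks for. The two dataclasses are local stand-ins that only mirror the documented fields (they are not imported from vllm_omni), and `describe_plan` is a hypothetical helper, not part of this module. Reading string keys as forward parameter names and integer keys as output positions is an interpretation inferred from the example above.

```python
from dataclasses import dataclass
from typing import Optional


# Stand-ins that mirror the documented fields only (assumption: the real
# classes in vllm_omni may carry more behaviour than shown here).
@dataclass(frozen=True)
class SequenceParallelInput:
    split_dim: int
    expected_dims: Optional[int] = None
    split_output: bool = False
    auto_pad: bool = False


@dataclass(frozen=True)
class SequenceParallelOutput:
    gather_dim: int
    expected_dims: Optional[int] = None


def describe_plan(plan: dict) -> list:
    """Hypothetical helper: summarise each plan entry as one line of text."""
    lines = []
    for module_name, spec in plan.items():
        where = module_name or "<model root>"
        if isinstance(spec, SequenceParallelOutput):
            lines.append(f"{where}: gather output on dim {spec.gather_dim}")
            continue
        # Otherwise spec maps a parameter name (str) or an output index (int)
        # to a split description.
        for key, item in spec.items():
            side = "output" if item.split_output else "input"
            lines.append(f"{where}: split {side} {key!r} on dim {item.split_dim}")
    return lines


plan = {
    "": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
    "pos_embed": {
        0: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),
    },
    "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
for line in describe_plan(plan):
    print(line)
```

Keeping the plan as plain data is what lets hooks apply it without touching the model's forward() method.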

AnySequenceParallelInput module-attribute

AnySequenceParallelInput = (
    SequenceParallelInput | SequenceParallelPartialInput
)

SequenceParallelInputType module-attribute

SequenceParallelModelPlan module-attribute

SequenceParallelModelPlan = dict[
    str,
    SequenceParallelInputType | SequenceParallelOutputType,
]

SequenceParallelOutputType module-attribute

SequenceParallelConfig dataclass

Configuration for Sequence Parallelism using vLLM-Omni's parallel state.

This class provides a unified interface for SP configuration that integrates with vLLM-Omni's existing SequenceParallelGroupCoordinator. Unlike diffusers' DeviceMesh-based approach (ContextParallelConfig), this uses the existing parallel state management.

Note: This corresponds to ContextParallelConfig in the diffusers library.

Parameters:

Name Type Description Default
ulysses_degree int

Number of devices for Ulysses (All-to-All) attention. Sequence is split across devices, with Q/K/V redistributed via All-to-All communication. Best for moderate sequences with good interconnect bandwidth.

1
ring_degree int

Number of devices for Ring attention. Sequence is split across devices, with K/V passed in a ring topology. Best for long sequences with limited memory/bandwidth.

1
convert_to_fp32 bool

Whether to convert output and LSE to float32 for numerical stability in ring attention.

True
Note

ulysses_degree * ring_degree = sequence_parallel_size.

vLLM-Omni supports hybrid Ulysses-Ring attention (both degrees > 1).
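As a small illustration of the relation in this note, the sketch below computes the total size from the two degrees and lists every way a given sequence parallel size can be factored. `SPConfigSketch` is a local stand-in that mirrors only the two documented degree fields, and `factorizations` is a hypothetical helper; neither is part of vllm_omni.

```python
from dataclasses import dataclass


@dataclass
class SPConfigSketch:
    """Local stand-in mirroring the documented fields; not the real class."""
    ulysses_degree: int = 1
    ring_degree: int = 1

    @property
    def sequence_parallel_size(self) -> int:
        # Documented relation: ulysses_degree * ring_degree.
        return self.ulysses_degree * self.ring_degree


def factorizations(sp_size: int) -> list:
    """Hypothetical helper: all (ulysses_degree, ring_degree) pairs for sp_size."""
    return [(u, sp_size // u) for u in range(1, sp_size + 1) if sp_size % u == 0]


hybrid = SPConfigSketch(ulysses_degree=2, ring_degree=4)
print(hybrid.sequence_parallel_size)
print(factorizations(8))
```

Pairs where both values exceed 1 correspond to the hybrid Ulysses-Ring case.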

convert_to_fp32 class-attribute instance-attribute

convert_to_fp32: bool = True

ring_degree class-attribute instance-attribute

ring_degree: int = 1

sequence_parallel_size property

sequence_parallel_size: int

Total sequence parallel world size.

ulysses_degree class-attribute instance-attribute

ulysses_degree: int = 1

get_rank

get_rank() -> int

Get the current rank in the sequence parallel group.

Returns:

Type Description
int

The rank within the sequence parallel group.

Raises:

Type Description
RuntimeError

If parallel state is not initialized.

get_ring_rank

get_ring_rank() -> int

Get the current rank in the Ring parallel group.

Returns:

Type Description
int

The rank within the Ring parallel group.

get_ring_world_size

get_ring_world_size() -> int

Get the Ring parallel world size.

Returns:

Type Description
int

The world size for Ring attention parallelism.

get_ulysses_rank

get_ulysses_rank() -> int

Get the current rank in the Ulysses parallel group.

Returns:

Type Description
int

The rank within the Ulysses parallel group.

get_ulysses_world_size

get_ulysses_world_size() -> int

Get the Ulysses parallel world size.

Returns:

Type Description
int

The world size for Ulysses (All-to-All) parallelism.

get_world_size

get_world_size() -> int

Get the sequence parallel world size from parallel state.

Returns:

Type Description
int

The world size for sequence parallelism.

Raises:

Type Description
RuntimeError

If parallel state is not initialized.

is_initialized

is_initialized() -> bool

Check if the config has been initialized with runtime state.

Returns:

Type Description
bool

True if setup() has been called, False otherwise.

setup

setup(rank: int, world_size: int, device: device) -> None

Initialize the config with runtime parallel state.

This is called automatically when sequence parallelism is enabled.

Parameters:

Name Type Description Default
rank int

The global rank of this process.

required
world_size int

Total world size.

required
device device

The device for this rank.

required

SequenceParallelInput dataclass

Configuration for splitting an input tensor across sequence parallel ranks.

This specifies how to shard a tensor in the pre-forward or post-forward hook of a layer. The tensor will be split along the specified dimension.

Note: This corresponds to ContextParallelInput in the diffusers library.

Parameters:

Name Type Description Default
split_dim int

The dimension along which to split the tensor.

required
expected_dims int | None

Expected number of dimensions. If provided, validates that the tensor has this many dimensions before splitting. If the tensor has a different number of dimensions, splitting is skipped with a warning.

None
split_output bool

If True, split the output of the layer instead of the input. This is useful for layers whose outputs should be split after preprocessing (e.g., RoPE embeddings).

False
auto_pad bool

If True, automatically pad the tensor if its size along split_dim is not divisible by world_size. An attention mask marking valid versus padding positions is created and stored in ForwardContext. Note: Ring attention does not support attention masks, so auto_pad should only be used with Ulysses SP.

False
Example

# Split hidden_states along sequence dimension (dim 1)
SequenceParallelInput(split_dim=1, expected_dims=3)

# Split RoPE output along sequence dimension (dim 0)
SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True)

# Split with auto-padding for variable-length sequences
SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True)
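The effect of split_dim and expected_dims on tensor shapes can be sketched with plain tuples. `local_shape` is a hypothetical helper written for this page; it follows the documented behaviour of skipping the split with a warning when the dimension count does not match, and it assumes (this is not stated above) that an uneven split without auto_pad is an error.

```python
import warnings


def local_shape(shape, split_dim, world_size, expected_dims=None):
    """Hypothetical helper: shape of one rank's shard after an even split."""
    if expected_dims is not None and len(shape) != expected_dims:
        # Documented behaviour: skip splitting and warn.
        warnings.warn(
            f"expected {expected_dims} dims, got {len(shape)}; not splitting"
        )
        return tuple(shape)
    if shape[split_dim] % world_size != 0:
        # Assumption: without auto_pad, an uneven split is rejected here.
        raise ValueError("size along split_dim is not divisible by world_size")
    out = list(shape)
    out[split_dim] //= world_size
    return tuple(out)


# (batch, sequence, hidden) split along the sequence dimension on 4 ranks.
print(local_shape((2, 4096, 64), split_dim=1, world_size=4, expected_dims=3))
```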

auto_pad class-attribute instance-attribute

auto_pad: bool = False
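The padding arithmetic behind auto_pad can be sketched on a 1-D sequence. `pad_to_multiple` is a hypothetical helper; right-padding, the pad value, and the boolean mask convention (True for valid positions) are assumptions made for illustration, not details confirmed by this page.

```python
def pad_to_multiple(tokens, world_size, pad_value=0):
    """Hypothetical helper: right-pad a 1-D sequence and build a validity mask."""
    # Number of extra positions needed to reach the next multiple of world_size.
    pad_len = (-len(tokens)) % world_size
    padded = list(tokens) + [pad_value] * pad_len
    mask = [True] * len(tokens) + [False] * pad_len
    return padded, mask


padded, mask = pad_to_multiple([11, 12, 13, 14, 15], world_size=4)
print(len(padded), mask)
```

After padding, the sequence length divides evenly by world_size, so every rank receives a shard of the same size.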

expected_dims class-attribute instance-attribute

expected_dims: int | None = None

split_dim instance-attribute

split_dim: int

split_output class-attribute instance-attribute

split_output: bool = False

SequenceParallelOutput dataclass

Configuration for gathering an output tensor across sequence parallel ranks.

This specifies how to gather a tensor in the post-forward hook of a layer. The tensor will be gathered along the specified dimension from all ranks.

Note: This corresponds to ContextParallelOutput in the diffusers library.

Parameters:

Name Type Description Default
gather_dim int

The dimension along which to gather the tensor.

required
expected_dims int | None

Expected number of dimensions. If provided, validates that the tensor has this many dimensions before gathering.

None
Example

# Gather output along sequence dimension (dim 1)
SequenceParallelOutput(gather_dim=1, expected_dims=3)
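Gathering is the inverse of splitting: concatenating the per-rank shards along gather_dim restores the full tensor. The sketch below shows this round trip on nested Python lists in a single process. `split_along` and `gather_along` are hypothetical helpers; a real implementation would operate on tensors and use collective communication across ranks.

```python
def split_along(x, dim, world_size):
    """Split a nested list into world_size equal chunks along dim."""
    if dim == 0:
        step = len(x) // world_size
        return [x[r * step:(r + 1) * step] for r in range(world_size)]
    # Split every sub-list, then regroup the pieces per rank.
    per_row = [split_along(row, dim - 1, world_size) for row in x]
    return [[pieces[r] for pieces in per_row] for r in range(world_size)]


def gather_along(shards, dim):
    """Concatenate per-rank nested lists along dim (inverse of split_along)."""
    if dim == 0:
        return [item for shard in shards for item in shard]
    return [
        gather_along([s[i] for s in shards], dim - 1)
        for i in range(len(shards[0]))
    ]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]  # shape (2, 4)
shards = split_along(x, dim=1, world_size=2)
print(shards)
print(gather_along(shards, dim=1) == x)
```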

expected_dims class-attribute instance-attribute

expected_dims: int | None = None

gather_dim instance-attribute

gather_dim: int

SequenceParallelPartialInput dataclass

Configuration for partially splitting a tensor (e.g., split image part, keep text part).

This is designed for models like LongCat/Qwen where RoPE embeddings need special handling:

- Text portion: kept full across all ranks (for joint attention)
- Image portion: split across ranks

The tensor is assumed to be concatenated as [text_part, image_part] along split_dim.

Note: This is an extension beyond diffusers' standard ContextParallelInput, designed for vLLM-Omni's dual-stream attention models.

Parameters:

Name Type Description Default
split_dim int

The dimension along which to split the image portion.

required
text_len_source str | int

How to determine text length:

- str: Name of a forward parameter that contains text length
- int: Fixed text length value

required
expected_dims int | None

Expected number of dimensions for validation.

None
split_output bool

If True, split the output instead of input.

False
Example

# Split RoPE: text portion (from txt_ids.shape[0]) kept full, image portion split
SequenceParallelPartialInput(
    split_dim=0,
    text_len_source="txt_ids",  # Get text length from txt_ids.shape[0]
    expected_dims=2,
    split_output=True,
)

# Or with fixed text length
SequenceParallelPartialInput(
    split_dim=0,
    text_len_source=512,  # Fixed text length
    expected_dims=2,
    split_output=True,
)
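A partial split can be sketched on a 1-D list laid out as [text_part, image_part]. `partial_split` is a hypothetical helper; returning the full text part followed by the rank's image chunk, and requiring the image part to divide evenly, are assumptions made for this illustration.

```python
def partial_split(seq, text_len, world_size, rank):
    """Hypothetical helper: keep seq[:text_len] whole, shard the rest by rank."""
    text, image = seq[:text_len], seq[text_len:]
    if len(image) % world_size != 0:
        # Assumption: the image portion must divide evenly in this sketch.
        raise ValueError("image portion is not divisible by world_size")
    step = len(image) // world_size
    return text + image[rank * step:(rank + 1) * step]


seq = ["t0", "t1", "i0", "i1", "i2", "i3"]  # [text_part, image_part]
for rank in range(2):
    print(rank, partial_split(seq, text_len=2, world_size=2, rank=rank))
```

Every rank sees the complete text positions, which is what joint attention over text and image tokens needs, while the image positions are divided among ranks.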

expected_dims class-attribute instance-attribute

expected_dims: int | None = None

split_dim instance-attribute

split_dim: int

split_output class-attribute instance-attribute

split_output: bool = False

text_len_source instance-attribute

text_len_source: str | int

get_sp_plan_from_model

get_sp_plan_from_model(
    model: Module,
) -> SequenceParallelModelPlan | None

Get the _sp_plan from a model if it exists.

Parameters:

Name Type Description Default
model Module

The model to get the plan from.

required

Returns:

Type Description
SequenceParallelModelPlan | None

The _sp_plan dictionary, or None if not defined.
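One plausible way such a lookup could behave is a guarded attribute read. The sketch below is an assumption about the behaviour described here, not the module's actual implementation; `get_plan_sketch` and the two toy classes are hypothetical.

```python
def get_plan_sketch(model):
    """Assumed behaviour: return the model's _sp_plan attribute, or None if absent."""
    return getattr(model, "_sp_plan", None)


class WithPlan:
    # Placeholder value for illustration; a real plan holds SP input/output specs.
    _sp_plan = {"proj_out": "placeholder-spec"}


class WithoutPlan:
    pass


print(get_plan_sketch(WithPlan()))
print(get_plan_sketch(WithoutPlan()))
```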

validate_sp_plan

validate_sp_plan(plan: SequenceParallelModelPlan) -> None

Validate a _sp_plan dictionary for correctness.

Parameters:

Name Type Description Default
plan SequenceParallelModelPlan

The _sp_plan dictionary to validate.

required

Raises:

Type Description
ValueError

If the plan is invalid.
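The specific checks performed by validate_sp_plan are not listed on this page. The sketch below only illustrates the kind of structural checks such a validator might make and how a caller could handle the resulting ValueError. `InputSpec`, `OutputSpec`, and `validate_plan_sketch` are hypothetical stand-ins, and the rules they enforce are assumptions.

```python
from dataclasses import dataclass


@dataclass
class InputSpec:  # stand-in for SequenceParallelInput
    split_dim: int


@dataclass
class OutputSpec:  # stand-in for SequenceParallelOutput
    gather_dim: int


def validate_plan_sketch(plan):
    """Hypothetical checks; the real validate_sp_plan may differ."""
    if not isinstance(plan, dict):
        raise ValueError("plan must be a dict")
    for name, spec in plan.items():
        if not isinstance(name, str):
            raise ValueError(f"module name must be str, got {name!r}")
        if isinstance(spec, OutputSpec):
            continue
        if not isinstance(spec, dict):
            raise ValueError(f"{name!r}: expected a dict of inputs or an output spec")
        for key, item in spec.items():
            if not isinstance(key, (str, int)) or not isinstance(item, InputSpec):
                raise ValueError(f"{name!r}: bad entry {key!r}")


# A well-formed plan passes silently.
validate_plan_sketch({"": {"hidden_states": InputSpec(1)}, "proj_out": OutputSpec(1)})

# A malformed plan is rejected with ValueError.
try:
    validate_plan_sketch({"proj_out": 42})
except ValueError as err:
    print("rejected:", err)
```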