
vllm_omni.diffusion.distributed

Distributed utilities for vLLM-Omni diffusion models.

Modules:

autoencoders
cfg_parallel: Base pipeline class for Diffusion models with shared CFG functionality.
comm
group_coordinator
hsdp
hsdp_utils
parallel_state: vLLM-Omni distributed state.
pipeline_parallel
sp_plan: Sequence Parallelism configuration and plan type definitions.
sp_sharding: Sequence Parallelism sharding utilities.
utils
vae_patch_parallel: Distributed VAE patch/tile parallelism utilities.

SequenceParallelModelPlan module-attribute

SequenceParallelModelPlan = dict[
    str,
    SequenceParallelInputType | SequenceParallelOutputType,
]

HSDPInferenceConfig dataclass

Configuration for HSDP inference.

This is a runtime config created from DiffusionParallelConfig's HSDP settings.

enabled class-attribute instance-attribute

enabled: bool = False

hsdp_replicate_size class-attribute instance-attribute

hsdp_replicate_size: int = 1

hsdp_shard_size class-attribute instance-attribute

hsdp_shard_size: int = -1

output_dtype class-attribute instance-attribute

output_dtype: dtype | None = None

param_dtype class-attribute instance-attribute

param_dtype: dtype = bfloat16

reduce_dtype class-attribute instance-attribute

reduce_dtype: dtype = float32

reshard_after_forward class-attribute instance-attribute

reshard_after_forward: bool = True
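The dtype fields and reshard_after_forward share their names with PyTorch FSDP2's MixedPrecisionPolicy (param_dtype, reduce_dtype, output_dtype) and fully_shard(reshard_after_forward=...), which suggests, without confirming, that they are forwarded there. The sketch below shows one way the two mesh sizes could be resolved into a 2-D (replicate, shard) rank grid. The helper resolve_hsdp_mesh is hypothetical, and treating hsdp_shard_size=-1 as "shard across every rank not used for replication" is an assumption.

```python
def resolve_hsdp_mesh(world_size: int, replicate_size: int = 1, shard_size: int = -1):
    """Hypothetical helper: lay out ranks as a (replicate, shard) grid."""
    if shard_size == -1:
        # Assumption: -1 means "use all ranks left over after replication".
        if world_size % replicate_size != 0:
            raise ValueError("world_size must be divisible by hsdp_replicate_size")
        shard_size = world_size // replicate_size
    if replicate_size * shard_size != world_size:
        raise ValueError("hsdp_replicate_size * hsdp_shard_size must equal world_size")
    # Row-major layout: each row is one shard group (parameters split across it);
    # each column is one replicate group (ranks holding identical shards).
    mesh = [
        list(range(r * shard_size, (r + 1) * shard_size))
        for r in range(replicate_size)
    ]
    return shard_size, mesh


shard_size, mesh = resolve_hsdp_mesh(world_size=8, replicate_size=2)
print(shard_size, mesh)  # 4 [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With this layout, each parameter is split four ways inside a row and the two rows hold identical copies, trading memory for less cross-group communication.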

SequenceParallelConfig dataclass

Configuration for Sequence Parallelism using vLLM-Omni's parallel state.

This class provides a unified interface for SP configuration that integrates with vLLM-Omni's existing SequenceParallelGroupCoordinator. Unlike diffusers' DeviceMesh-based approach (ContextParallelConfig), this uses the existing parallel state management.

Note: This corresponds to ContextParallelConfig in the diffusers library.

Parameters:

ulysses_degree (int, default 1): Number of devices for Ulysses (All-to-All) attention. The sequence is split across devices, and Q/K/V are redistributed via All-to-All communication. Best for moderate-length sequences with good interconnect bandwidth.
ring_degree (int, default 1): Number of devices for Ring attention. The sequence is split across devices, and K/V blocks are passed around a ring topology. Best for long sequences with limited memory/bandwidth.
convert_to_fp32 (bool, default True): Whether to convert the output and LSE (log-sum-exp) to float32 for numerical stability in ring attention.

Note

ulysses_degree * ring_degree = sequence_parallel_size. vLLM-Omni supports hybrid Ulysses-Ring attention (both degrees > 1).

convert_to_fp32 class-attribute instance-attribute

convert_to_fp32: bool = True

ring_degree class-attribute instance-attribute

ring_degree: int = 1

sequence_parallel_size property

sequence_parallel_size: int

Total sequence parallel world size.

ulysses_degree class-attribute instance-attribute

ulysses_degree: int = 1
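Because ulysses_degree * ring_degree = sequence_parallel_size, every rank in the sequence parallel group can be described by a (ulysses_rank, ring_rank) pair, which is what get_ulysses_rank and get_ring_rank report. The sketch below assumes Ulysses groups are made of contiguous ranks (Ulysses is the innermost dimension); that layout is an assumption, and decompose_sp_rank is a hypothetical helper, not part of vLLM-Omni.

```python
def decompose_sp_rank(sp_rank: int, ulysses_degree: int, ring_degree: int) -> tuple[int, int]:
    """Hypothetical mapping from an SP rank to (ulysses_rank, ring_rank)."""
    sp_size = ulysses_degree * ring_degree  # == sequence_parallel_size
    if not 0 <= sp_rank < sp_size:
        raise ValueError(f"sp_rank {sp_rank} out of range for size {sp_size}")
    # Assumed layout: ranks 0..ulysses_degree-1 form the first Ulysses group,
    # the next ulysses_degree ranks form the second, and so on.
    return sp_rank % ulysses_degree, sp_rank // ulysses_degree


# 8-way SP expressed as 4-way Ulysses x 2-way Ring.
for r in range(8):
    print(r, decompose_sp_rank(r, ulysses_degree=4, ring_degree=2))
```

Placing Ulysses innermost keeps the bandwidth-heavy All-to-All traffic among neighboring ranks (often the same node), while the ring steps cross the slower links.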

get_rank

get_rank() -> int

Get the current rank in the sequence parallel group.

Returns:

int: The rank within the sequence parallel group.

Raises:

RuntimeError: If parallel state is not initialized.

get_ring_rank

get_ring_rank() -> int

Get the current rank in the Ring parallel group.

Returns:

int: The rank within the Ring parallel group.

get_ring_world_size

get_ring_world_size() -> int

Get the Ring parallel world size.

Returns:

int: The world size for Ring attention parallelism.

get_ulysses_rank

get_ulysses_rank() -> int

Get the current rank in the Ulysses parallel group.

Returns:

int: The rank within the Ulysses parallel group.

get_ulysses_world_size

get_ulysses_world_size() -> int

Get the Ulysses parallel world size.

Returns:

int: The world size for Ulysses (All-to-All) parallelism.

get_world_size

get_world_size() -> int

Get the sequence parallel world size from parallel state.

Returns:

int: The world size for sequence parallelism.

Raises:

RuntimeError: If parallel state is not initialized.

is_initialized

is_initialized() -> bool

Check if the config has been initialized with runtime state.

Returns:

bool: True if setup() has been called, False otherwise.

setup

setup(rank: int, world_size: int, device: device) -> None

Initialize the config with runtime parallel state.

This is called automatically when sequence parallelism is enabled.

Parameters:

rank (int, required): The global rank of this process.
world_size (int, required): Total world size.
device (device, required): The device for this rank.

SequenceParallelInput dataclass

Configuration for splitting an input tensor across sequence parallel ranks.

This specifies how to shard a tensor in the pre-forward or post-forward hook of a layer. The tensor will be split along the specified dimension.

Note: This corresponds to ContextParallelInput in the diffusers library.

Parameters:

split_dim (int, required): The dimension along which to split the tensor.
expected_dims (int | None, default None): Expected number of dimensions. If provided, the tensor's number of dimensions is checked before splitting; if it differs, splitting is skipped with a warning.
split_output (bool, default False): If True, split the layer's output instead of its input. This is useful for layers whose outputs should be split after preprocessing (e.g., RoPE embeddings).
auto_pad (bool, default False): If True, automatically pad the tensor when its size along split_dim is not divisible by world_size, and create an attention mask that marks valid vs. padding positions. The mask is stored in ForwardContext. Note: Ring attention does not support attention masks, so auto_pad should only be used with Ulysses SP.
Example

Split hidden_states along sequence dimension (dim 1)

SequenceParallelInput(split_dim=1, expected_dims=3)

Split RoPE output along sequence dimension (dim 0)

SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True)

Split with auto-padding for variable-length sequences

SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True)

auto_pad class-attribute instance-attribute

auto_pad: bool = False

expected_dims class-attribute instance-attribute

expected_dims: int | None = None

split_dim instance-attribute

split_dim: int

split_output class-attribute instance-attribute

split_output: bool = False
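The auto_pad behavior can be pictured on a single sequence dimension: pad to the next multiple of world_size, give each rank an equal chunk, and record which positions are real. The sketch below uses a Python list in place of a tensor, and split_with_auto_pad is a hypothetical helper, not the library's implementation. It returns the mask for the full padded sequence on the assumption that this is what Ulysses attention needs, since after its All-to-All each rank attends over the whole sequence for a subset of heads.

```python
def split_with_auto_pad(seq, world_size, rank, pad_value=0):
    """Hypothetical sketch of auto_pad=True on one (sequence) dimension."""
    # Smallest amount of padding that makes the length divisible by world_size.
    pad = (-len(seq)) % world_size
    padded = list(seq) + [pad_value] * pad
    # Mask over the full padded sequence: True = real token, False = padding.
    mask = [True] * len(seq) + [False] * pad
    chunk = len(padded) // world_size
    local = padded[rank * chunk:(rank + 1) * chunk]
    return local, mask


tokens = ["t0", "t1", "t2", "t3", "t4"]  # length 5, not divisible by 2
for rank in range(2):
    print(rank, split_with_auto_pad(tokens, world_size=2, rank=rank, pad_value="<pad>"))
```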

SequenceParallelOutput dataclass

Configuration for gathering an output tensor across sequence parallel ranks.

This specifies how to gather a tensor in the post-forward hook of a layer. The tensor will be gathered along the specified dimension from all ranks.

Note: This corresponds to ContextParallelOutput in the diffusers library.

Parameters:

gather_dim (int, required): The dimension along which to gather the tensor.
expected_dims (int | None, default None): Expected number of dimensions. If provided, validates that the tensor has this many dimensions before gathering.
Example

Gather output along sequence dimension (dim 1)

SequenceParallelOutput(gather_dim=1, expected_dims=3)

expected_dims class-attribute instance-attribute

expected_dims: int | None = None

gather_dim instance-attribute

gather_dim: int

SequenceParallelPartialInput dataclass

Configuration for partially splitting a tensor (e.g., split image part, keep text part).

This is designed for models like LongCat/Qwen where RoPE embeddings need special handling:

- Text portion: kept full across all ranks (for joint attention)
- Image portion: split across ranks

The tensor is assumed to be concatenated as [text_part, image_part] along split_dim.

Note: This is an extension beyond diffusers' standard ContextParallelInput, designed for vLLM-Omni's dual-stream attention models.

Parameters:

split_dim (int, required): The dimension along which to split the image portion.
text_len_source (str | int, required): How to determine the text length:
    - str: name of a forward parameter from which the text length is read (in the example below, txt_ids.shape[0])
    - int: fixed text length value
expected_dims (int | None, default None): Expected number of dimensions for validation.
split_output (bool, default False): If True, split the output instead of the input.
Example

Split RoPE: text portion (from txt_ids.shape[0]) kept full, image portion split

SequenceParallelPartialInput(
    split_dim=0,
    text_len_source="txt_ids",  # Get text length from txt_ids.shape[0]
    expected_dims=2,
    split_output=True,
)

Or with fixed text length

SequenceParallelPartialInput(
    split_dim=0,
    text_len_source=512,  # Fixed text length
    expected_dims=2,
    split_output=True,
)

expected_dims class-attribute instance-attribute

expected_dims: int | None = None

split_dim instance-attribute

split_dim: int

split_output class-attribute instance-attribute

split_output: bool = False

text_len_source instance-attribute

text_len_source: str | int
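The partial split keeps the text prefix on every rank and shards only the image suffix. The sketch below shows that on a list standing in for the RoPE tensor's split dimension; resolve_text_len and partial_split are hypothetical helpers, and requiring the image length to be divisible by world_size is an assumption (the real implementation may pad or handle it differently). For a tensor, len(t) equals t.shape[0], which is why len() stands in for txt_ids.shape[0] here.

```python
def resolve_text_len(text_len_source, forward_kwargs):
    """Hypothetical: int -> fixed length; str -> length of the named forward argument."""
    if isinstance(text_len_source, int):
        return text_len_source
    return len(forward_kwargs[text_len_source])  # stands in for .shape[0]


def partial_split(seq, text_len, world_size, rank):
    """Hypothetical sketch: keep the text prefix on every rank, shard the image suffix."""
    text, image = seq[:text_len], seq[text_len:]
    if len(image) % world_size != 0:
        raise ValueError("image length must be divisible by world_size")
    chunk = len(image) // world_size
    return text + image[rank * chunk:(rank + 1) * chunk]


# 2 text positions followed by 4 image positions, split over 2 ranks.
rope = ["txt0", "txt1", "img0", "img1", "img2", "img3"]
text_len = resolve_text_len("txt_ids", {"txt_ids": ["a", "b"]})
print(partial_split(rope, text_len, world_size=2, rank=0))  # ['txt0', 'txt1', 'img0', 'img1']
print(partial_split(rope, text_len, world_size=2, rank=1))  # ['txt0', 'txt1', 'img2', 'img3']
```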

ShardingValidator dataclass

Validator for tracking and verifying sharding operations.

This class helps ensure that sharding and gathering operations are correctly paired in model forward passes. It tracks which tensors have been sharded and verifies that they are properly gathered.

Usage

validator = ShardingValidator()
with validator.track():
    hidden_states = validator.shard(hidden_states, "hidden_states", dim=1)
    # ... model computation ...
    output = validator.gather(output, "hidden_states", dim=1)
validator.validate()  # Raises if any shard was not gathered

Attributes:

_sharded (set[str]): Set of tensor names that have been sharded.
_gathered (set[str]): Set of tensor names that have been gathered.
_enabled (bool): Whether tracking is currently enabled.

gather

gather(tensor: Tensor, name: str, dim: int) -> Tensor

Gather a tensor and track the operation.

Parameters:

tensor (Tensor, required): The local shard to gather.
name (str, required): The name used when sharding (for validation).
dim (int, required): The dimension along which to gather.

Returns:

Tensor: The gathered tensor.

reset

reset() -> None

Reset the validator state for a new forward pass.

shard

shard(
    tensor: Tensor,
    name: str,
    dim: int,
    validate_divisible: bool = True,
) -> Tensor

Shard a tensor and track the operation.

Parameters:

tensor (Tensor, required): The tensor to shard.
name (str, required): A name to identify this tensor for validation.
dim (int, required): The dimension along which to split.
validate_divisible (bool, default True): If True, validate that the tensor's size along dim is divisible by the sequence parallel world size.

Returns:

Tensor: The sharded tensor.

track

track()

Context manager to enable tracking for a forward pass.

validate

validate() -> None

Validate that all sharded tensors were gathered.

Raises:

ValueError: If any sharded tensor was not gathered.
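The validator's bookkeeping amounts to two name sets and a set difference. The sketch below re-implements only that bookkeeping, with no communication: MiniShardingValidator is a hypothetical class, shard/gather return their input unchanged where the real methods would split or all-gather, and resetting state on entry to track() is an assumption.

```python
from contextlib import contextmanager


class MiniShardingValidator:
    """Hypothetical re-implementation of the tracking logic only (no communication)."""

    def __init__(self):
        self._sharded: set[str] = set()
        self._gathered: set[str] = set()
        self._enabled = False

    def reset(self):
        self._sharded.clear()
        self._gathered.clear()

    @contextmanager
    def track(self):
        self.reset()  # assumption: each tracked forward pass starts clean
        self._enabled = True
        try:
            yield self
        finally:
            self._enabled = False

    def shard(self, tensor, name, dim):
        if self._enabled:
            self._sharded.add(name)
        return tensor  # the real method would split tensor along dim

    def gather(self, tensor, name, dim):
        if self._enabled:
            self._gathered.add(name)
        return tensor  # the real method would all-gather along dim

    def validate(self):
        missing = self._sharded - self._gathered
        if missing:
            raise ValueError(f"sharded but never gathered: {sorted(missing)}")


v = MiniShardingValidator()
with v.track():
    h = v.shard([1, 2, 3, 4], "hidden_states", dim=1)
    v.shard([0], "encoder_hidden_states", dim=1)  # forgotten: never gathered
    h = v.gather(h, "hidden_states", dim=1)
try:
    v.validate()
except ValueError as e:
    print(e)  # sharded but never gathered: ['encoder_hidden_states']
```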

apply_hsdp_to_model

apply_hsdp_to_model(
    model: Module, hsdp_config: HSDPInferenceConfig
) -> Module

Apply HSDP sharding to a model that already has weights loaded.

This function redistributes the model's parameters across GPUs using HSDP. The model should already have its weights loaded via the standard load_weights method.

Parameters:

model (Module, required): Model instance with weights already loaded.
hsdp_config (HSDPInferenceConfig, required): HSDP configuration with the HSDP mesh dimensions.

Returns:

Module: HSDP-wrapped model ready for inference.

get_fs_group

get_fs_group() -> GroupCoordinator

get_fully_shard_rank

get_fully_shard_rank()

Return this process's rank in the fully shard group.

get_fully_shard_world_size

get_fully_shard_world_size()

Return the world size of the fully shard group.

get_sharding_validator

get_sharding_validator() -> ShardingValidator

Get the global sharding validator instance.

Returns:

ShardingValidator: The global ShardingValidator.

get_sp_plan_from_model

get_sp_plan_from_model(
    model: Module,
) -> SequenceParallelModelPlan | None

Get the _sp_plan from a model if it exists.

Parameters:

model (Module, required): The model to get the plan from.

Returns:

SequenceParallelModelPlan | None: The _sp_plan dictionary, or None if not defined.
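Models opt in to sequence parallelism by defining an _sp_plan attribute, typically on the class so all instances share it. A plausible equivalent of the lookup is a plain getattr with a None default; get_sp_plan_sketch and the two model classes below are hypothetical illustrations, and the plan value is a placeholder rather than a real spec.

```python
def get_sp_plan_sketch(model):
    """Hypothetical equivalent of get_sp_plan_from_model: read an optional attribute."""
    # Models without a plan simply lack the attribute, so default to None.
    return getattr(model, "_sp_plan", None)


class PlainTransformer:
    pass


class SPReadyTransformer:
    # Declared on the class, so every instance shares one plan.
    _sp_plan = {"proj_out": "gather along dim 1"}  # placeholder spec for illustration


for model in (PlainTransformer(), SPReadyTransformer()):
    plan = get_sp_plan_sketch(model)
    print(type(model).__name__, "->", "no SP hooks" if plan is None else sorted(plan))
```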

sp_gather

sp_gather(
    tensor: Tensor, dim: int, validate: bool = True
) -> Tensor

Gather a tensor along the specified dimension from all sequence parallel ranks.

The sharded tensors from all ranks are concatenated along dim.

Parameters:

tensor (Tensor, required): The local shard to gather.
dim (int, required): The dimension along which to gather.
validate (bool, default True): If True, validate tensor consistency (currently unused).

Returns:

Tensor: The full tensor gathered from all ranks.

Example

At end of model forward:

output = sp_gather(output, dim=1)

sp_shard

sp_shard(
    tensor: Tensor, dim: int, validate: bool = True
) -> Tensor

Shard a tensor along the specified dimension for sequence parallelism.

The tensor is split into world_size chunks along dim, and this rank receives its corresponding chunk.

Parameters:

tensor (Tensor, required): The tensor to shard.
dim (int, required): The dimension along which to split.
validate (bool, default True): If True, validate that the tensor's size along dim is divisible by world_size.

Returns:

Tensor: The shard for this rank.

Raises:

ValueError: If validate=True and the tensor's size along dim is not divisible by world_size.

Example

In model forward:

hidden_states = sp_shard(hidden_states, dim=1)
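The shard/gather contract (rank r keeps chunk r of world_size equal chunks, and gathering concatenates the chunks in rank order) can be checked in a single process. The sketch below simulates every rank with Python lists along one dimension; sp_shard_sim and sp_gather_sim are hypothetical stand-ins, not the library functions, and they always enforce the divisibility check that validate=True performs.

```python
def sp_shard_sim(seq, world_size, rank):
    """Hypothetical single-process stand-in for sp_shard along one dimension."""
    if len(seq) % world_size != 0:
        raise ValueError(f"size {len(seq)} not divisible by world_size {world_size}")
    chunk = len(seq) // world_size
    # Rank r receives the r-th contiguous chunk.
    return seq[rank * chunk:(rank + 1) * chunk]


def sp_gather_sim(shards):
    """Concatenate per-rank shards in rank order, as an all-gather along dim would."""
    out = []
    for shard in shards:
        out.extend(shard)
    return out


seq = list(range(8))
shards = [sp_shard_sim(seq, 4, r) for r in range(4)]
print(shards)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(sp_gather_sim(shards) == seq)  # True
```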

sp_shard_with_padding

sp_shard_with_padding(
    tensor: Tensor, dim: int, pad_value: float = 0.0
) -> tuple[Tensor, int]

Shard a tensor with automatic padding if not divisible by world_size.

This is useful for variable-length sequences where padding may be needed.

Parameters:

tensor (Tensor, required): The tensor to shard.
dim (int, required): The dimension along which to split.
pad_value (float, default 0.0): Value to use for padding.

Returns:

tuple[Tensor, int]: A tuple (sharded_tensor, padding_size), where padding_size is how much padding was added to the original tensor before sharding.

Example

sharded, pad_size = sp_shard_with_padding(hidden_states, dim=1)
# ... process ...
output = sp_gather(output, dim=1)
if pad_size > 0:
    # Remove padding along the sharded dimension (dim 1), not the last dimension
    output = output.narrow(1, 0, output.size(1) - pad_size)
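The pad, shard, gather, trim round trip only restores the original data if the padding is trimmed along the same dimension that was padded. The sketch below runs the round trip on one dimension with Python lists; shard_with_padding_sim is a hypothetical stand-in for sp_shard_with_padding, and for real tensors Tensor.narrow(dim, 0, size - pad_size) does the same trim along an arbitrary dimension.

```python
def shard_with_padding_sim(seq, world_size, rank, pad_value=0.0):
    """Hypothetical single-process stand-in for sp_shard_with_padding on one dimension."""
    pad_size = (-len(seq)) % world_size  # smallest pad that makes the length divisible
    padded = list(seq) + [pad_value] * pad_size
    chunk = len(padded) // world_size
    return padded[rank * chunk:(rank + 1) * chunk], pad_size


world_size = 4
seq = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # length 6, padded to 8
results = [shard_with_padding_sim(seq, world_size, r) for r in range(world_size)]
pad_size = results[0][1]

# Gather: concatenate the per-rank shards in rank order.
gathered = [x for shard, _ in results for x in shard]

# Trim along the sharded dimension. An explicit end index avoids the
# seq[:-0] pitfall, which would return an empty list when pad_size == 0.
restored = gathered[: len(gathered) - pad_size]
print(pad_size, restored)  # 2 [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```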

validate_sp_plan

validate_sp_plan(plan: SequenceParallelModelPlan) -> None

Validate a _sp_plan dictionary for correctness.

Parameters:

plan (SequenceParallelModelPlan, required): The _sp_plan dictionary to validate.

Raises:

ValueError: If the plan is invalid.