vllm_omni.diffusion.distributed
Distributed utilities for vLLM-Omni diffusion models.
Modules:
| Name | Description |
|---|---|
autoencoders | |
cfg_parallel | Base pipeline class for Diffusion models with shared CFG functionality. |
comm | |
group_coordinator | |
hsdp | |
hsdp_utils | |
parallel_state | vLLM-Omni distributed state. |
pipeline_parallel | |
sp_plan | Sequence Parallelism configuration and plan type definitions. |
sp_sharding | Sequence Parallelism sharding utilities. |
utils | |
vae_patch_parallel | Distributed VAE patch/tile parallelism utilities. |
SequenceParallelModelPlan module-attribute
SequenceParallelModelPlan = dict[
str,
SequenceParallelInputType | SequenceParallelOutputType,
]
HSDPInferenceConfig dataclass
Configuration for HSDP (Hybrid Sharded Data Parallel) inference.
This is a runtime config created from DiffusionParallelConfig's HSDP settings.
SequenceParallelConfig dataclass
Configuration for Sequence Parallelism using vLLM-Omni's parallel state.
This class provides a unified interface for SP configuration that integrates with vLLM-Omni's existing SequenceParallelGroupCoordinator. Unlike diffusers' DeviceMesh-based approach (ContextParallelConfig), this uses the existing parallel state management.
Note: This corresponds to ContextParallelConfig in diffusers library.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
ulysses_degree | int | Number of devices for Ulysses (All-to-All) attention. The sequence is split across devices, and Q/K/V are redistributed via All-to-All communication. Best for moderate sequence lengths with good interconnect bandwidth. | 1 |
ring_degree | int | Number of devices for Ring attention. Sequence is split across devices, with K/V passed in a ring topology. Best for long sequences with limited memory/bandwidth. | 1 |
convert_to_fp32 | bool | Whether to convert the output and the log-sum-exp (LSE) to float32 for numerical stability in Ring attention. | True |
Note
ulysses_degree * ring_degree = sequence_parallel_size. vLLM-Omni supports hybrid Ulysses-Ring attention, where both degrees are greater than 1.
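To illustrate how the two degrees combine in hybrid mode, the sketch below factors a sequence-parallel rank into a Ulysses rank and a Ring rank, loosely mirroring what get_ulysses_rank and get_ring_rank report. The Ulysses-major layout (consecutive ranks form one Ulysses group) and the helper name `split_sp_rank` are assumptions for illustration; vLLM-Omni's actual group construction may order ranks differently.

```python
def split_sp_rank(sp_rank: int, ulysses_degree: int, ring_degree: int) -> tuple[int, int]:
    """Hypothetical helper: map an SP rank to (ulysses_rank, ring_rank).

    Assumes a Ulysses-major layout: ranks 0..ulysses_degree-1 form the first
    Ulysses group, the next ulysses_degree ranks form the second, and so on.
    """
    sp_size = ulysses_degree * ring_degree  # sequence_parallel_size
    if not 0 <= sp_rank < sp_size:
        raise ValueError(f"sp_rank {sp_rank} out of range for size {sp_size}")
    ulysses_rank = sp_rank % ulysses_degree   # position inside the All-to-All group
    ring_rank = sp_rank // ulysses_degree     # position along the ring
    return ulysses_rank, ring_rank


# Hybrid Ulysses-Ring: 2 x 2 = 4 sequence-parallel ranks.
for r in range(4):
    print(r, split_sp_rank(r, ulysses_degree=2, ring_degree=2))
```

With both degrees greater than 1, every rank belongs to exactly one Ulysses group and one Ring group, and the product of the two group sizes recovers the sequence-parallel world size.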
get_rank
get_rank() -> int
Get the current rank in the sequence parallel group.
Returns:
| Type | Description |
|---|---|
int | The rank within the sequence parallel group. |
Raises:
| Type | Description |
|---|---|
RuntimeError | If parallel state is not initialized. |
get_ring_rank
get_ring_rank() -> int
Get the current rank in the Ring parallel group.
Returns:
| Type | Description |
|---|---|
int | The rank within the Ring parallel group. |
get_ring_world_size
get_ring_world_size() -> int
Get the Ring parallel world size.
Returns:
| Type | Description |
|---|---|
int | The world size for Ring attention parallelism. |
get_ulysses_rank
get_ulysses_rank() -> int
Get the current rank in the Ulysses parallel group.
Returns:
| Type | Description |
|---|---|
int | The rank within the Ulysses parallel group. |
get_ulysses_world_size
get_ulysses_world_size() -> int
Get the Ulysses parallel world size.
Returns:
| Type | Description |
|---|---|
int | The world size for Ulysses (All-to-All) parallelism. |
get_world_size
get_world_size() -> int
Get the sequence parallel world size from parallel state.
Returns:
| Type | Description |
|---|---|
int | The world size for sequence parallelism. |
Raises:
| Type | Description |
|---|---|
RuntimeError | If parallel state is not initialized. |
is_initialized
is_initialized() -> bool
Check if the config has been initialized with runtime state.
Returns:
| Type | Description |
|---|---|
bool | True if setup() has been called, False otherwise. |
setup
Initialize the config with runtime parallel state.
This is called automatically when sequence parallelism is enabled.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rank | int | The global rank of this process. | required |
world_size | int | Total world size. | required |
device | device | The device for this rank. | required |
SequenceParallelInput dataclass
Configuration for splitting an input tensor across sequence parallel ranks.
This specifies how to shard a tensor in the pre-forward or post-forward hook of a layer. The tensor will be split along the specified dimension.
Note: This corresponds to ContextParallelInput in diffusers library.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
split_dim | int | The dimension along which to split the tensor. | required |
expected_dims | int \| None | Expected number of dimensions. If provided, validates that the tensor has this many dimensions before splitting; if the tensor has a different number of dimensions, splitting is skipped with a warning. | None |
split_output | bool | If True, split the output of the layer instead of the input. This is useful for layers whose outputs should be split after preprocessing (e.g., RoPE embeddings). | False |
auto_pad | bool | If True, automatically pad the tensor if its size along split_dim is not divisible by world_size. An attention mask marking valid vs. padding positions is created and stored in ForwardContext. Note: Ring attention does not support attention masks, so auto_pad should only be used with Ulysses SP. | False |
Example
```python
from vllm_omni.diffusion.distributed import SequenceParallelInput

# Split hidden_states along the sequence dimension (dim 1)
SequenceParallelInput(split_dim=1, expected_dims=3)

# Split the RoPE output along the sequence dimension (dim 0)
SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True)

# Split with auto-padding for variable-length sequences
SequenceParallelInput(split_dim=1, expected_dims=3, auto_pad=True)
```
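The auto_pad behavior can be pictured with plain integers: the length along split_dim is rounded up to a multiple of world_size, and a boolean mask marks which positions hold real tokens. This is a sketch under assumptions: the helper name `pad_and_mask` and the choice to append padding at the end of the sequence are illustrative, not taken from vLLM-Omni's implementation.

```python
def pad_and_mask(seq_len: int, world_size: int) -> tuple[int, list[bool]]:
    """Hypothetical helper: padded length and validity mask for auto_pad.

    Padding is assumed to be appended at the end of the split dimension.
    """
    pad = (-seq_len) % world_size  # 0 when seq_len is already divisible
    padded_len = seq_len + pad
    # True = real token, False = padding position to be masked out in attention.
    mask = [i < seq_len for i in range(padded_len)]
    return padded_len, mask


padded_len, mask = pad_and_mask(seq_len=10, world_size=4)
chunk = padded_len // 4
# Each rank receives an equal-sized slice of the mask (and of the tensor).
rank_masks = [mask[r * chunk:(r + 1) * chunk] for r in range(4)]
print(padded_len, rank_masks[-1])
```

Here a length of 10 is padded to 12, and only the last rank ends up holding padding positions. Because the mask must be honored inside attention, this fits Ulysses SP; Ring attention does not support masks.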
SequenceParallelOutput dataclass
Configuration for gathering an output tensor across sequence parallel ranks.
This specifies how to gather a tensor in the post-forward hook of a layer. The tensor will be gathered along the specified dimension from all ranks.
Note: This corresponds to ContextParallelOutput in diffusers library.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
gather_dim | int | The dimension along which to gather the tensor. | required |
expected_dims | int \| None | Expected number of dimensions. If provided, validates that the tensor has this many dimensions before gathering. | None |
Example
```python
from vllm_omni.diffusion.distributed import SequenceParallelOutput

# Gather the output along the sequence dimension (dim 1)
SequenceParallelOutput(gather_dim=1, expected_dims=3)
```
SequenceParallelPartialInput dataclass
Configuration for partially splitting a tensor (e.g., split image part, keep text part).
This is designed for models like LongCat/Qwen where RoPE embeddings need special handling:
- Text portion: kept full across all ranks (for joint attention)
- Image portion: split across ranks
The tensor is assumed to be concatenated as [text_part, image_part] along split_dim.
Note: This is an extension beyond diffusers' standard ContextParallelInput, designed for vLLM-Omni's dual-stream attention models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
split_dim | int | The dimension along which to split the image portion. | required |
text_len_source | str \| int | How to determine the text length. A str names a forward parameter from which the text length is read (e.g. "txt_ids", whose shape[0] is used); an int gives a fixed text length. | required |
expected_dims | int \| None | Expected number of dimensions for validation. | None |
split_output | bool | If True, split the output instead of input. | False |
Example
```python
from vllm_omni.diffusion.distributed import SequenceParallelPartialInput

# Split RoPE: text portion (length from txt_ids.shape[0]) kept full, image portion split
SequenceParallelPartialInput(
    split_dim=0,
    text_len_source="txt_ids",  # Get text length from txt_ids.shape[0]
    expected_dims=2,
    split_output=True,
)

# Or with a fixed text length
SequenceParallelPartialInput(
    split_dim=0,
    text_len_source=512,  # Fixed text length
    expected_dims=2,
    split_output=True,
)
```
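To make the [text_part, image_part] layout concrete, the sketch below applies the same idea to a Python list standing in for a RoPE tensor. The function name `partial_split` is hypothetical, and it assumes the image portion is evenly divisible by world_size; it is not vLLM-Omni's implementation.

```python
def partial_split(seq: list, text_len: int, rank: int, world_size: int) -> list:
    """Hypothetical sketch: keep the text prefix whole, shard the image suffix."""
    text_part, image_part = seq[:text_len], seq[text_len:]
    if len(image_part) % world_size != 0:
        raise ValueError("image portion must be divisible by world_size")
    chunk = len(image_part) // world_size
    # Every rank sees the full text part plus its own slice of the image part.
    return text_part + image_part[rank * chunk:(rank + 1) * chunk]


# 2 text positions followed by 4 image positions.
seq = ["t0", "t1", "i0", "i1", "i2", "i3"]
print(partial_split(seq, text_len=2, rank=1, world_size=2))  # ['t0', 't1', 'i2', 'i3']
```

Keeping the text prefix on every rank is what lets joint attention see all text tokens while each rank only holds part of the image sequence.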
ShardingValidator dataclass
Validator for tracking and verifying sharding operations.
This class helps ensure that sharding and gathering operations are correctly paired in model forward passes. It tracks which tensors have been sharded and verifies that they are properly gathered.
Usage
```python
from vllm_omni.diffusion.distributed import ShardingValidator

validator = ShardingValidator()
with validator.track():
    hidden_states = validator.shard(hidden_states, "hidden_states", dim=1)
    # ... model computation ...
    output = validator.gather(output, "hidden_states", dim=1)
validator.validate()  # Raises if any shard was not gathered
```
Attributes:
| Name | Type | Description |
|---|---|---|
_sharded | set[str] | Set of tensor names that have been sharded. |
_gathered | set[str] | Set of tensor names that have been gathered. |
_enabled | bool | Whether tracking is currently enabled. |
gather
Gather a tensor and track the operation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tensor | Tensor | The local shard to gather. | required |
name | str | The name used when sharding (for validation). | required |
dim | int | The dimension along which to gather. | required |
Returns:
| Type | Description |
|---|---|
Tensor | The gathered tensor. |
shard
Shard a tensor and track the operation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tensor | Tensor | The tensor to shard. | required |
name | str | A name to identify this tensor for validation. | required |
dim | int | The dimension along which to split. | required |
validate_divisible | bool | If True, validate divisibility. | True |
Returns:
| Type | Description |
|---|---|
Tensor | The sharded tensor. |
validate
Validate that all sharded tensors were gathered.
Raises:
| Type | Description |
|---|---|
ValueError | If any sharded tensor was not gathered. |
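The pairing check can be approximated with two sets, mirroring the documented _sharded and _gathered attributes. The class below is a hypothetical stand-in: it records names only, performs no tensor communication, and its error message is illustrative rather than ShardingValidator's actual text.

```python
class NameTracker:
    """Hypothetical stand-in for ShardingValidator's shard/gather pairing check."""

    def __init__(self) -> None:
        self._sharded: set[str] = set()
        self._gathered: set[str] = set()

    def shard(self, name: str) -> None:
        self._sharded.add(name)

    def gather(self, name: str) -> None:
        self._gathered.add(name)

    def validate(self) -> None:
        # Any name that was sharded but never gathered is an error.
        missing = self._sharded - self._gathered
        if missing:
            raise ValueError(f"sharded but not gathered: {sorted(missing)}")


tracker = NameTracker()
tracker.shard("hidden_states")
tracker.gather("hidden_states")
tracker.validate()  # passes: every shard was gathered

tracker.shard("encoder_hidden_states")
try:
    tracker.validate()
except ValueError as exc:
    print(exc)  # sharded but not gathered: ['encoder_hidden_states']
```

Tracking by name rather than by tensor identity means the same name must be passed to both shard and gather, which matches the `name` parameter described for both methods.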
apply_hsdp_to_model
apply_hsdp_to_model(
model: Module, hsdp_config: HSDPInferenceConfig
) -> Module
Apply HSDP sharding to a model that already has weights loaded.
This function redistributes the model's parameters across GPUs using HSDP. The model should already have its weights loaded via the standard load_weights method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | Model instance with weights already loaded | required |
hsdp_config | HSDPInferenceConfig | HSDP configuration, including the HSDP mesh dimensions | required |
Returns:
| Type | Description |
|---|---|
Module | HSDP-wrapped model ready for inference |
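HSDP arranges ranks in a two-dimensional mesh: parameters are fully sharded along one dimension and replicated along the other. The sketch below computes both groups for a given rank. The row-major (replicate, shard) layout, the helper name `hsdp_groups`, and the parameter names `replicate_size` and `shard_size` are assumptions for illustration; they are not HSDPInferenceConfig's actual fields.

```python
def hsdp_groups(rank: int, replicate_size: int, shard_size: int) -> tuple[list[int], list[int]]:
    """Hypothetical sketch: (shard_group, replicate_group) for a rank.

    Assumes a row-major (replicate, shard) mesh: each row of shard_size
    consecutive ranks holds one full copy of the parameters, split across the row.
    """
    world_size = replicate_size * shard_size
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world size {world_size}")
    row, col = divmod(rank, shard_size)
    # Ranks that split one copy of the parameters among themselves.
    shard_group = [row * shard_size + c for c in range(shard_size)]
    # Ranks that hold the same parameter shard, one per replica.
    replicate_group = [r * shard_size + col for r in range(replicate_size)]
    return shard_group, replicate_group


# 8 GPUs: 2 replicas, each sharded over 4 GPUs.
print(hsdp_groups(5, replicate_size=2, shard_size=4))  # ([4, 5, 6, 7], [1, 5])
```

In this layout, parameter all-gathers during the forward pass stay within a shard group, so placing each shard group on GPUs with fast interconnect keeps that traffic cheap.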
get_fully_shard_world_size
Return the world size of the fully-shard group.
get_sharding_validator
get_sharding_validator() -> ShardingValidator
Get the global sharding validator instance.
Returns:
| Type | Description |
|---|---|
ShardingValidator | The global ShardingValidator. |
get_sp_plan_from_model
get_sp_plan_from_model(
model: Module,
) -> SequenceParallelModelPlan | None
Get the _sp_plan from a model if it exists.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | The model to get the plan from. | required |
Returns:
| Type | Description |
|---|---|
SequenceParallelModelPlan \| None | The _sp_plan dictionary, or None if not defined. |
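In spirit, looking up the plan amounts to reading an optional attribute on the model. The sketch below assumes `_sp_plan` is a plain class attribute whose layout is modeled loosely on the `_cp_plan` convention in diffusers (submodule names as keys, with "" for the model itself). Strings stand in for SequenceParallelInput/SequenceParallelOutput instances, and `ToyTransformer`, `PlainModel`, and `lookup_sp_plan` are hypothetical names; the real function may do more.

```python
from typing import Optional


class ToyTransformer:
    # Placeholder plan: keys name submodules ("" = the model itself, assumed),
    # values describe what to split or gather there. Strings stand in for
    # SequenceParallelInput / SequenceParallelOutput instances.
    _sp_plan = {
        "": {"hidden_states": "split dim=1"},
        "proj_out": "gather dim=1",
    }


class PlainModel:
    pass


def lookup_sp_plan(model: object) -> Optional[dict]:
    """Hypothetical equivalent of get_sp_plan_from_model: None if no plan."""
    return getattr(model, "_sp_plan", None)


print(lookup_sp_plan(ToyTransformer())["proj_out"])  # gather dim=1
print(lookup_sp_plan(PlainModel()))  # None
```

Declaring the plan as data on the model class keeps the model's forward code free of explicit shard/gather calls; the hooks read the plan and apply the splits and gathers around the named submodules.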
sp_gather
Gather a tensor along the specified dimension from all sequence parallel ranks.
The sharded tensors from all ranks are concatenated along dim.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tensor | Tensor | The local shard to gather. | required |
dim | int | The dimension along which to gather. | required |
validate | bool | If True, validate tensor consistency (currently unused). | True |
Returns:
| Type | Description |
|---|---|
Tensor | The full tensor gathered from all ranks. |
sp_shard
Shard a tensor along the specified dimension for sequence parallelism.
The tensor is split into world_size chunks along dim, and this rank receives its corresponding chunk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tensor | Tensor | The tensor to shard. | required |
dim | int | The dimension along which to split. | required |
validate | bool | If True, validate that the tensor size is divisible by world_size. | True |
Returns:
| Type | Description |
|---|---|
Tensor | The shard for this rank. |
Raises:
| Type | Description |
|---|---|
ValueError | If validate=True and tensor size is not divisible by world_size. |
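The split-then-concatenate contract of sp_shard and sp_gather can be checked on a plain list, treating it as dim 0. Both functions below (`shard_list`, `gather_list`) are hypothetical single-process stand-ins: they simulate every rank locally instead of performing collective communication.

```python
def shard_list(seq: list, rank: int, world_size: int) -> list:
    """Hypothetical single-process stand-in for sp_shard along dim 0."""
    if len(seq) % world_size != 0:
        raise ValueError(f"length {len(seq)} not divisible by world_size {world_size}")
    chunk = len(seq) // world_size
    return seq[rank * chunk:(rank + 1) * chunk]


def gather_list(shards: list) -> list:
    """Concatenate per-rank shards in rank order, as sp_gather does along dim."""
    out = []
    for shard in shards:
        out.extend(shard)
    return out


tokens = list(range(8))
world_size = 4
shards = [shard_list(tokens, r, world_size) for r in range(world_size)]
print(shards[1])                      # [2, 3]
assert gather_list(shards) == tokens  # gather undoes shard
```

The divisibility check matters because, without it, equal-sized chunks would silently drop the remainder and the gather would no longer reproduce the original tensor.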
sp_shard_with_padding
Shard a tensor with automatic padding if not divisible by world_size.
This is useful for variable-length sequences where padding may be needed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tensor | Tensor | The tensor to shard. | required |
dim | int | The dimension along which to split. | required |
pad_value | float | Value to use for padding. | 0.0 |
Returns:
| Type | Description |
|---|---|
Tensor | The shard for this rank (first element of the returned tuple). |
int | The padding size: how much padding was added to the original tensor along dim before sharding (second element of the returned tuple). |
Example
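```python
from vllm_omni.diffusion.distributed import sp_gather, sp_shard_with_padding

sharded, pad_size = sp_shard_with_padding(hidden_states, dim=1)
# ... process sharded to produce output ...
output = sp_gather(output, dim=1)
if pad_size > 0:
    # Remove padding along the gathered dimension (dim 1), not the last dim.
    output = output[:, :-pad_size]
```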
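The padding arithmetic and the trim after gathering can be checked without a GPU. This sketch assumes padding is appended at the end of dim and uses nested lists in place of a [batch, seq, hidden] tensor; `padding_for` is a hypothetical helper. It also shows why the trim is guarded: a slice ending at `-pad_size` returns an empty sequence when pad_size is 0, so the sketch uses an explicit end index instead.

```python
def padding_for(length: int, world_size: int) -> int:
    """Hypothetical helper: padding needed to make length divisible by world_size."""
    return (-length) % world_size


batch, seq_len, hidden, world_size = 1, 10, 2, 4
pad_size = padding_for(seq_len, world_size)  # 2 -> padded length 12

# Nested lists standing in for the [batch, seq, hidden] output of
# sp_gather(..., dim=1): the last pad_size sequence positions are padding.
gathered = [
    [[float(s)] * hidden for s in range(seq_len + pad_size)]
    for _ in range(batch)
]

# Trim along dim 1 (sequence), not the last dim. An explicit end index avoids
# the [:-0] pitfall, which would return an empty slice when pad_size == 0.
trimmed = [sample[: len(sample) - pad_size] for sample in gathered]

print(pad_size, len(trimmed[0]), len(trimmed[0][0]))  # 2 10 2
```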
validate_sp_plan
validate_sp_plan(plan: SequenceParallelModelPlan) -> None
Validate a _sp_plan dictionary for correctness.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
plan | SequenceParallelModelPlan | The _sp_plan dictionary to validate. | required |
Raises:
| Type | Description |
|---|---|
ValueError | If the plan is invalid. |