vllm_omni.diffusion.distributed.sp_sharding

Sequence Parallelism sharding utilities.

This module provides low-level sharding and gathering functions for Sequence Parallelism. These can be used directly in model forward methods for semi-intrusive SP support, or internally by the SP hooks.

logger module-attribute

logger = init_logger(__name__)

ShardingValidator dataclass

Validator for tracking and verifying sharding operations.

This class helps ensure that sharding and gathering operations are correctly paired in model forward passes. It tracks which tensors have been sharded and verifies that they are properly gathered.

Usage

validator = ShardingValidator()
with validator.track():
    hidden_states = validator.shard(hidden_states, "hidden_states", dim=1)
    # ... model computation ...
    output = validator.gather(output, "hidden_states", dim=1)
validator.validate()  # Raises if any shard was not gathered

Attributes:

Name Type Description
_sharded set[str]

Set of tensor names that have been sharded.

_gathered set[str]

Set of tensor names that have been gathered.

_enabled bool

Whether tracking is currently enabled.
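The pairing check described above can be illustrated with a minimal, pure-Python sketch. `PairingTracker`, `note_shard`, and `note_gather` are hypothetical names used only for illustration; this is not the library's implementation, it performs no tensor operations, and resetting state on entry to `track()` is an assumption.

```python
from contextlib import contextmanager
from dataclasses import dataclass, field


@dataclass
class PairingTracker:
    """Hypothetical stand-in that mirrors only the bookkeeping attributes."""

    _sharded: set[str] = field(default_factory=set)
    _gathered: set[str] = field(default_factory=set)
    _enabled: bool = False

    @contextmanager
    def track(self):
        # Assumption: each tracked forward pass starts from a clean state.
        self.reset()
        self._enabled = True
        try:
            yield self
        finally:
            self._enabled = False

    def reset(self) -> None:
        self._sharded.clear()
        self._gathered.clear()

    def note_shard(self, name: str) -> None:
        # Record the name only while tracking is enabled.
        if self._enabled:
            self._sharded.add(name)

    def note_gather(self, name: str) -> None:
        if self._enabled:
            self._gathered.add(name)

    def validate(self) -> None:
        # Any name that was sharded but never gathered is an error.
        missing = self._sharded - self._gathered
        if missing:
            raise ValueError(f"sharded but never gathered: {sorted(missing)}")
```

Keeping two sets rather than a counter makes the error message able to name exactly which tensors were left sharded.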

gather

gather(tensor: Tensor, name: str, dim: int) -> Tensor

Gather a tensor and track the operation.

Parameters:

Name Type Description Default
tensor Tensor

The local shard to gather.

required
name str

The name used when sharding (for validation).

required
dim int

The dimension along which to gather.

required

Returns:

Type Description
Tensor

The gathered tensor.

reset

reset() -> None

Reset the validator state for a new forward pass.

shard

shard(
    tensor: Tensor,
    name: str,
    dim: int,
    validate_divisible: bool = True,
) -> Tensor

Shard a tensor and track the operation.

Parameters:

Name Type Description Default
tensor Tensor

The tensor to shard.

required
name str

A name to identify this tensor for validation.

required
dim int

The dimension along which to split.

required
validate_divisible bool

If True, validate that the tensor size along dim is divisible by world_size.

True

Returns:

Type Description
Tensor

The sharded tensor.

track

track()

Context manager to enable tracking for a forward pass.

validate

validate() -> None

Validate that all sharded tensors were gathered.

Raises:

Type Description
ValueError

If any sharded tensor was not gathered.

get_sharding_validator

get_sharding_validator() -> ShardingValidator

Get the global sharding validator instance.

Returns:

Type Description
ShardingValidator

The global ShardingValidator.

sp_gather

sp_gather(
    tensor: Tensor, dim: int, validate: bool = True
) -> Tensor

Gather a tensor along the specified dimension from all sequence parallel ranks.

The sharded tensors from all ranks are concatenated along dim.

Parameters:

Name Type Description Default
tensor Tensor

The local shard to gather.

required
dim int

The dimension along which to gather.

required
validate bool

If True, validate tensor consistency (currently unused).

True

Returns:

Type Description
Tensor

The full tensor gathered from all ranks.

Example

At end of model forward:

output = sp_gather(output, dim=1)
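To make the concatenation semantics concrete, here is a single-process sketch that uses nested lists of shape [batch][seq] in place of tensors. `simulate_gather` is a hypothetical helper, not part of this module, and concatenating in rank order (0, 1, ...) is an assumption; the real function communicates across ranks rather than receiving all shards as an argument.

```python
def simulate_gather(shards_by_rank: list[list[list[int]]]) -> list[list[int]]:
    """Concatenate per-rank [batch][seq] shards along the sequence dim (dim=1)."""
    batch_size = len(shards_by_rank[0])
    full = []
    for b in range(batch_size):
        row = []
        # Assumption: shards are joined in rank order 0, 1, ..., world_size - 1.
        for shard in shards_by_rank:
            row.extend(shard[b])
        full.append(row)
    return full


# Two ranks, batch size 2, local sequence length 2 on each rank.
rank0 = [[0, 1], [10, 11]]
rank1 = [[2, 3], [12, 13]]
full = simulate_gather([rank0, rank1])
```

After the gather, every rank would hold the same full-length sequence, here of length 4 per batch row.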

sp_shard

sp_shard(
    tensor: Tensor, dim: int, validate: bool = True
) -> Tensor

Shard a tensor along the specified dimension for sequence parallelism.

The tensor is split into world_size chunks along dim, and this rank receives its corresponding chunk.

Parameters:

Name Type Description Default
tensor Tensor

The tensor to shard.

required
dim int

The dimension along which to split.

required
validate bool

If True, validate that the tensor size along dim is divisible by world_size.

True

Returns:

Type Description
Tensor

The shard for this rank.

Raises:

Type Description
ValueError

If validate=True and the tensor size along dim is not divisible by world_size.

Example

In model forward:

hidden_states = sp_shard(hidden_states, dim=1)
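The chunk-per-rank rule can be sketched in a single process with a flat list standing in for the sharded dimension. `simulate_shard` is a hypothetical helper and takes `rank` and `world_size` explicitly, whereas the real function presumably reads them from the sequence parallel group. The sketch always enforces divisibility; behavior with validate=False is not modeled here.

```python
def simulate_shard(seq: list[int], rank: int, world_size: int) -> list[int]:
    """Return the contiguous chunk of `seq` that belongs to `rank`."""
    n = len(seq)
    if n % world_size != 0:
        raise ValueError(
            f"size {n} along the sharded dim is not divisible by world_size={world_size}"
        )
    chunk = n // world_size
    # Rank r owns elements [r * chunk, (r + 1) * chunk).
    return seq[rank * chunk : (rank + 1) * chunk]


tokens = list(range(8))
local = simulate_shard(tokens, rank=1, world_size=4)
```

With 8 elements and 4 ranks, each rank keeps 2 elements, so per-rank sequence work shrinks by a factor of world_size.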

sp_shard_with_padding

sp_shard_with_padding(
    tensor: Tensor, dim: int, pad_value: float = 0.0
) -> tuple[Tensor, int]

Shard a tensor with automatic padding if not divisible by world_size.

This is useful for variable-length sequences where padding may be needed.

Parameters:

Name Type Description Default
tensor Tensor

The tensor to shard.

required
dim int

The dimension along which to split.

required
pad_value float

Value to use for padding.

0.0

Returns:

Type Description
Tensor

tuple[Tensor, int]

Tuple of (sharded_tensor, padding_size). The padding_size indicates how much padding was added to the original tensor before sharding.

Example

sharded, pad_size = sp_shard_with_padding(hidden_states, dim=1)

# ... process ...

output = sp_gather(output, dim=1)
if pad_size > 0:
    output = output[:, :-pad_size]  # Remove padding along dim=1
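The padding arithmetic behind this pattern can be sketched without tensors. `pad_and_shard` is a hypothetical helper operating on a flat list; appending the padding at the end of the sharded dimension is an assumption, chosen to match the trailing-slice removal in the example above.

```python
def pad_and_shard(
    seq: list[float], rank: int, world_size: int, pad_value: float = 0.0
) -> tuple[list[float], int]:
    """Pad `seq` up to a multiple of world_size, then return this rank's chunk."""
    # Smallest non-negative amount that makes the length divisible.
    pad_size = (-len(seq)) % world_size
    padded = seq + [pad_value] * pad_size  # assumption: padding goes at the end
    chunk = len(padded) // world_size
    return padded[rank * chunk : (rank + 1) * chunk], pad_size


seq = [1.0, 2.0, 3.0, 4.0, 5.0]
world_size = 4
results = [pad_and_shard(seq, r, world_size) for r in range(world_size)]
pad_size = results[0][1]

# Simulated gather: concatenate the chunks in rank order, then strip the padding.
gathered = [x for shard, _ in results for x in shard]
restored = gathered[: len(gathered) - pad_size]
```

Slicing with an explicit end index, rather than `[:-pad_size]`, also behaves correctly when pad_size is 0, which is why the example above guards the negative slice with `if pad_size > 0`.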