vllm_omni.diffusion.distributed.sp_sharding
Sequence Parallelism sharding utilities.
This module provides low-level sharding and gathering functions for Sequence Parallelism. These can be used directly in model forward methods for semi-intrusive SP support, or internally by the SP hooks.
ShardingValidator dataclass
Validator for tracking and verifying sharding operations.
This class helps ensure that sharding and gathering operations are correctly paired in model forward passes. It tracks which tensors have been sharded and verifies that they are properly gathered.
Usage
```python
from vllm_omni.diffusion.distributed.sp_sharding import ShardingValidator

validator = ShardingValidator()
with validator.track():
    hidden_states = validator.shard(hidden_states, "hidden_states", dim=1)
    # ... model computation ...
    output = validator.gather(output, "hidden_states", dim=1)
validator.validate()  # Raises if any shard was not gathered
```
Attributes:
| Name | Type | Description |
|---|---|---|
| _sharded | set[str] | Set of tensor names that have been sharded. |
| _gathered | set[str] | Set of tensor names that have been gathered. |
| _enabled | bool | Whether tracking is currently enabled. |
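The bookkeeping these attributes describe can be sketched in plain Python. This is a hypothetical stand-in (`ToyShardingValidator`), not the library's class: it assumes that entering `track()` clears both sets and enables tracking for the duration of the block, and its `shard`/`gather` only record names and return the input unchanged, so the sketch runs without a sequence-parallel process group.

```python
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Iterator


@dataclass
class ToyShardingValidator:
    """Hypothetical sketch of the tracking logic; not the vllm_omni class."""

    _sharded: set[str] = field(default_factory=set)
    _gathered: set[str] = field(default_factory=set)
    _enabled: bool = False

    @contextmanager
    def track(self) -> Iterator["ToyShardingValidator"]:
        # Assumption: entering track() clears earlier records and enables tracking.
        self._sharded.clear()
        self._gathered.clear()
        self._enabled = True
        try:
            yield self
        finally:
            self._enabled = False

    def shard(self, tensor: Any, name: str, dim: int) -> Any:
        # The real method shards the tensor; this sketch only records the name.
        if self._enabled:
            self._sharded.add(name)
        return tensor

    def gather(self, tensor: Any, name: str, dim: int) -> Any:
        # The real method gathers the tensor; this sketch only records the name.
        if self._enabled:
            self._gathered.add(name)
        return tensor

    def validate(self) -> None:
        # Any name that was sharded but never gathered is an error.
        missing = self._sharded - self._gathered
        if missing:
            raise ValueError(f"Sharded but never gathered: {sorted(missing)}")


v = ToyShardingValidator()
with v.track():
    h = v.shard([1, 2, 3, 4], "hidden_states", dim=1)
    v.shard([5, 6], "encoder_hidden_states", dim=1)
    v.gather(h, "hidden_states", dim=1)
try:
    v.validate()
except ValueError as err:
    print(err)  # Sharded but never gathered: ['encoder_hidden_states']
```

Keying the check on names rather than tensor identities is what lets a shard in one place be matched with a gather later in the forward pass, even though the gathered tensor is a different object.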
ShardingValidator.gather
Gather a tensor and track the operation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | The local shard to gather. | required |
| name | str | The name used when sharding (for validation). | required |
| dim | int | The dimension along which to gather. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | The gathered tensor. |
ShardingValidator.shard
Shard a tensor and track the operation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | The tensor to shard. | required |
| name | str | A name to identify this tensor for validation. | required |
| dim | int | The dimension along which to split. | required |
| validate_divisible | bool | If True, validate that the size of tensor along dim is divisible by world_size. | True |
Returns:
| Type | Description |
|---|---|
| Tensor | The sharded tensor. |
ShardingValidator.validate
Validate that all sharded tensors were gathered.
Raises:
| Type | Description |
|---|---|
| ValueError | If any sharded tensor was not gathered. |
get_sharding_validator
get_sharding_validator() -> ShardingValidator
Get the global sharding validator instance.
Returns:
| Type | Description |
|---|---|
| ShardingValidator | The global ShardingValidator. |
sp_gather
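Because the getter returns one shared instance, code in different places (for example an SP hook and a model forward) can record and validate against the same state. A minimal sketch of that pattern follows; `ToyShardingValidator` and `get_toy_sharding_validator` are hypothetical names, and lazy creation on first call is an assumption, since the reference does not say when the global instance is created.

```python
from typing import Optional


class ToyShardingValidator:
    """Hypothetical placeholder for ShardingValidator in this sketch."""


_global_validator: Optional[ToyShardingValidator] = None


def get_toy_sharding_validator() -> ToyShardingValidator:
    """Return the process-wide validator, creating it on first use (assumed lazy)."""
    global _global_validator
    if _global_validator is None:
        _global_validator = ToyShardingValidator()
    return _global_validator


# Two independent callers receive the very same object.
assert get_toy_sharding_validator() is get_toy_sharding_validator()
```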
Gather a tensor along the specified dimension from all sequence parallel ranks.
The sharded tensors from all ranks are concatenated along dim.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | The local shard to gather. | required |
| dim | int | The dimension along which to gather. | required |
| validate | bool | If True, validate tensor consistency (currently unused). | True |
Returns:
| Type | Description |
|---|---|
| Tensor | The full tensor gathered from all ranks. |
sp_shard
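To make the concatenation concrete, the sketch below simulates the gather in a single process. It is hypothetical: tensors are modelled as nested Python lists of shape `[batch, seq_len]`, `dim` is fixed to 1 (the sequence axis), the per-rank shards are passed in directly instead of being exchanged by a collective, and shards are assumed to be joined in rank order, which is what makes gathering the inverse of `sp_shard`.

```python
from itertools import chain


def simulate_sp_gather(shards_by_rank: list[list[list[int]]]) -> list[list[int]]:
    """Concatenate per-rank [batch, local_seq] shards along dim=1, in rank order.

    Hypothetical single-process stand-in for sp_gather(tensor, dim=1).
    """
    batch = len(shards_by_rank[0])
    # For each batch row, join that row's pieces from rank 0, rank 1, ... in order.
    return [
        list(chain.from_iterable(shard[b] for shard in shards_by_rank))
        for b in range(batch)
    ]


rank0 = [[1, 2], [5, 6]]  # local shard held by SP rank 0
rank1 = [[3, 4], [7, 8]]  # local shard held by SP rank 1
print(simulate_sp_gather([rank0, rank1]))  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```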
Shard a tensor along the specified dimension for sequence parallelism.
The tensor is split into world_size chunks along dim, and this rank receives its corresponding chunk.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | The tensor to shard. | required |
| dim | int | The dimension along which to split. | required |
| validate | bool | If True, validate that the size of tensor along dim is divisible by world_size. | True |
Returns:
| Type | Description |
|---|---|
| Tensor | The shard for this rank. |
Raises:
| Type | Description |
|---|---|
| ValueError | If validate=True and the size of tensor along dim is not divisible by world_size. |
sp_shard_with_padding
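The splitting rule can be simulated in one process under the same hypothetical modelling: `[batch, seq_len]` nested lists sharded along dim=1, with `rank` and `world_size` passed in explicitly instead of being read from the SP group. The sketch always enforces divisibility (it models `validate=True`) and assumes each rank keeps the contiguous chunk at its rank index.

```python
def simulate_sp_shard(rows: list[list[int]], rank: int, world_size: int) -> list[list[int]]:
    """Return SP rank `rank`'s contiguous slice of every row (dim=1).

    Hypothetical stand-in for sp_shard(tensor, dim=1, validate=True).
    """
    seq_len = len(rows[0])
    if seq_len % world_size != 0:
        raise ValueError(f"seq_len={seq_len} is not divisible by world_size={world_size}")
    chunk = seq_len // world_size
    start = rank * chunk
    # Every row contributes the same [start, start + chunk) window to this rank.
    return [row[start : start + chunk] for row in rows]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]  # [batch=2, seq_len=4]
print(simulate_sp_shard(x, rank=0, world_size=2))  # [[1, 2], [5, 6]]
print(simulate_sp_shard(x, rank=1, world_size=2))  # [[3, 4], [7, 8]]
try:
    simulate_sp_shard(x, rank=0, world_size=3)
except ValueError as err:
    print(err)  # seq_len=4 is not divisible by world_size=3
```

Concatenating the shards of ranks 0 through world_size - 1 along dim=1 recovers the original rows, which is the property the matching gather relies on.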
Shard a tensor, first padding it along dim if its size along dim is not divisible by world_size.
This is useful for variable-length sequences whose length is not guaranteed to be a multiple of world_size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | The tensor to shard. | required |
| dim | int | The dimension along which to split. | required |
| pad_value | float | Value to use for padding. | 0.0 |
Returns:
| Type | Description |
|---|---|
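| Tensor | sharded_tensor: this rank's shard of the padded tensor (first element of the returned (sharded_tensor, padding_size) tuple). |
| int | padding_size: how much padding was added to the original tensor along dim before sharding (second element of the tuple). |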
Example
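```python
from vllm_omni.diffusion.distributed.sp_sharding import sp_gather, sp_shard_with_padding

sharded, pad_size = sp_shard_with_padding(hidden_states, dim=1)
# ... process ...
output = sp_gather(output, dim=1)
if pad_size > 0:
    output = output[:, :-pad_size]  # Remove padding from dim=1, not from the last dim
```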
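The pad, shard, gather, unpad round trip can be simulated in a single process. This sketch is hypothetical and rests on assumptions the reference does not state: padding_size is the smallest amount that makes the length divisible, the padding is appended at the end of the sequence axis (consistent with the example trimming trailing elements), and tensors are `[batch, seq_len]` nested lists sharded along dim=1.

```python
def pad_and_shard(rows, rank, world_size, pad_value=0.0):
    """Hypothetical stand-in for sp_shard_with_padding(tensor, dim=1)."""
    seq_len = len(rows[0])
    pad_size = -seq_len % world_size  # smallest padding that makes seq_len divisible
    padded = [row + [pad_value] * pad_size for row in rows]
    chunk = (seq_len + pad_size) // world_size
    start = rank * chunk
    return [row[start : start + chunk] for row in padded], pad_size


def gather_and_unpad(shards_by_rank, pad_size):
    """Concatenate shards along dim=1 in rank order, then drop the trailing padding."""
    batch = len(shards_by_rank[0])
    full = [[v for shard in shards_by_rank for v in shard[b]] for b in range(batch)]
    # Slicing to len - pad_size is also correct when pad_size == 0.
    return [row[: len(row) - pad_size] for row in full]


x = [[1.0, 2.0, 3.0, 4.0, 5.0]]  # [batch=1, seq_len=5]
world_size = 2
results = [pad_and_shard(x, r, world_size) for r in range(world_size)]
shards = [s for s, _ in results]
pad = results[0][1]
print(pad)  # 1
print(shards)  # [[[1.0, 2.0, 3.0]], [[4.0, 5.0, 0.0]]]
print(gather_and_unpad(shards, pad) == x)  # True
```

Slicing to `len(row) - pad_size` avoids the pitfall that `row[:-0]` is empty; the `if pad_size > 0` guard in the example above exists for the same reason.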