vllm_omni.diffusion.distributed.hsdp

logger module-attribute

logger = init_logger(__name__)

HSDPInferenceConfig dataclass

Configuration for HSDP inference.

This is a runtime config created from DiffusionParallelConfig's HSDP settings.

enabled class-attribute instance-attribute

enabled: bool = False

hsdp_replicate_size class-attribute instance-attribute

hsdp_replicate_size: int = 1

hsdp_shard_size class-attribute instance-attribute

hsdp_shard_size: int = -1

output_dtype class-attribute instance-attribute

output_dtype: dtype | None = None

param_dtype class-attribute instance-attribute

param_dtype: dtype = bfloat16

reduce_dtype class-attribute instance-attribute

reduce_dtype: dtype = float32

reshard_after_forward class-attribute instance-attribute

reshard_after_forward: bool = True
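
The two size fields above can be read as the dimensions of a two-dimensional device mesh: a replicate dimension and a shard dimension. The following is a minimal sketch of that reading, not the library's code. `HSDPConfigSketch` is a hypothetical stand-alone mirror of the documented fields (dtypes are written as strings so the block runs without PyTorch), and `resolve_mesh_shape` is a hypothetical helper. It assumes that `hsdp_shard_size = -1` means "use all ranks left over after replication" and that the two sizes must multiply to the world size; the source states neither.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class HSDPConfigSketch:
    """Hypothetical mirror of the documented HSDPInferenceConfig fields."""

    enabled: bool = False
    hsdp_replicate_size: int = 1
    hsdp_shard_size: int = -1
    output_dtype: str | None = None  # stand-in for torch.dtype | None
    param_dtype: str = "bfloat16"    # stand-in for torch.bfloat16
    reduce_dtype: str = "float32"    # stand-in for torch.float32
    reshard_after_forward: bool = True


def resolve_mesh_shape(config: HSDPConfigSketch, world_size: int) -> tuple[int, int]:
    """Return (replicate_size, shard_size) under the assumptions stated above."""
    replicate = config.hsdp_replicate_size
    if replicate < 1 or world_size % replicate != 0:
        raise ValueError("replicate size must be positive and divide world_size")

    shard = config.hsdp_shard_size
    if shard == -1:
        # Assumption: -1 means "fill the remaining ranks".
        shard = world_size // replicate

    if replicate * shard != world_size:
        raise ValueError("replicate_size * shard_size must equal world_size")
    return replicate, shard


config = HSDPConfigSketch(enabled=True, hsdp_replicate_size=2)
print(resolve_mesh_shape(config, world_size=8))  # (2, 4)
```

With 8 ranks and a replicate size of 2, this reading gives two replica groups, each sharding the parameters across 4 ranks.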

apply_hsdp_to_model

apply_hsdp_to_model(
    model: Module, hsdp_config: HSDPInferenceConfig
) -> Module

Apply HSDP sharding to a model that already has weights loaded.

This function redistributes the model's parameters across GPUs using HSDP. The model should already have its weights loaded via the standard load_weights method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | Model instance with weights already loaded | required |
| hsdp_config | HSDPInferenceConfig | HSDP configuration with HSDP mesh dimensions | required |

Returns:

| Type | Description |
| --- | --- |
| Module | HSDP-wrapped model ready for inference |

shard_model

shard_model(
    model: Module,
    *,
    reshard_after_forward: bool = True,
    mp_policy: MixedPrecisionPolicy | None = None,
    mesh: DeviceMesh | None = None,
    hsdp_shard_conditions: list[
        Callable[[str, Module], bool]
    ],
) -> None

Apply HSDP sharding to model modules based on shard conditions.
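
Per the signature, each entry in `hsdp_shard_conditions` is a callable that receives a string and a module (presumably the module's qualified name and the module itself) and returns a bool. The sketch below shows one way such predicates could be written and evaluated. It is an illustration only: `FakeBlock`, `FakeNorm`, `is_block_instance`, `is_top_level_block_name` and `select_modules_to_shard` are hypothetical names, plain Python objects stand in for `torch.nn.Module`, and selecting a module when any condition matches is an assumption, since the source does not say how multiple conditions are combined.

```python
from __future__ import annotations

from typing import Callable

# Same shape as the documented condition type, with `object` standing in for Module.
ShardCondition = Callable[[str, object], bool]


class FakeBlock:
    """Hypothetical stand-in for a transformer block module."""


class FakeNorm:
    """Hypothetical stand-in for a small layer that is left unsharded."""


def is_block_instance(name: str, module: object) -> bool:
    # Select by module type.
    return isinstance(module, FakeBlock)


def is_top_level_block_name(name: str, module: object) -> bool:
    # Select by qualified name, e.g. "transformer_blocks.0".
    return name.startswith("transformer_blocks.") and name.count(".") == 1


def select_modules_to_shard(
    named_modules: list[tuple[str, object]],
    conditions: list[ShardCondition],
) -> list[str]:
    """Return names of modules for which at least one condition holds (assumed rule)."""
    return [
        name
        for name, module in named_modules
        if any(condition(name, module) for condition in conditions)
    ]


named_modules = [
    ("transformer_blocks.0", FakeBlock()),
    ("transformer_blocks.1", FakeBlock()),
    ("norm_out", FakeNorm()),
]
selected = select_modules_to_shard(named_modules, [is_block_instance])
print(selected)  # ['transformer_blocks.0', 'transformer_blocks.1']
```

Predicates keyed on type and predicates keyed on name can be mixed in the same list, which is one reason a `(name, module)` signature is convenient.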