vllm_omni.diffusion.distributed.vae_patch_parallel

Distributed VAE patch/tile parallelism utilities.

logger module-attribute

logger = init_logger(__name__)

VaePatchParallelism

Patch/tile-parallel VAE decode wrapper.

It is meant to be installed as an instance-level override of vae.decode, so pipelines do not need model-specific code paths.
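The instance-level override pattern can be sketched in plain Python. This is a minimal illustration, not the library's code. ToyVAE, wrap_instance_decode and logging_wrapper are hypothetical names, and the wrapper only records calls instead of performing any parallel decode. The key point is that assigning to the instance attribute shadows the class method for that one object only.

```python
from typing import Any, Callable


class ToyVAE:
    """Hypothetical stand-in for a diffusers-style VAE."""

    def decode(self, z: list[float], return_dict: bool = True) -> Any:
        out = [2.0 * v for v in z]  # pretend "decoding" doubles each value
        return {"sample": out} if return_dict else (out,)


def wrap_instance_decode(
    vae: Any, wrapper_factory: Callable[[Callable[..., Any]], Callable[..., Any]]
) -> None:
    # Keep the original *bound* method so the wrapper can delegate to it.
    original = vae.decode
    # Assigning to the instance attribute shadows the class method for this
    # object only; other instances of the class keep the original behaviour.
    vae.decode = wrapper_factory(original)


calls: list[int] = []


def logging_wrapper(original: Callable[..., Any]) -> Callable[..., Any]:
    # Same call shape as the documented decode(z, return_dict=True, *args, **kwargs).
    def decode(z: list[float], return_dict: bool = True, *args: Any, **kwargs: Any) -> Any:
        calls.append(len(z))  # hypothetical hook where parallel work would happen
        return original(z, return_dict, *args, **kwargs)

    return decode


vae = ToyVAE()
wrap_instance_decode(vae, logging_wrapper)
print(vae.decode([1.0, 2.0]))  # {'sample': [2.0, 4.0]}
print(ToyVAE().decode([1.0]))  # unwrapped instance: {'sample': [2.0]}
print(calls)                   # [2]
```

Because only the instance is patched, a pipeline can call vae.decode exactly as before, and any other VAE objects in the process are unaffected.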

decode

decode(
    z: Tensor,
    return_dict: bool = True,
    *args: Any,
    **kwargs: Any,
)
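The idea behind patch-parallel decode can be illustrated with a single-process simulation: split the latent into bands, decode each band as if on its own rank, then stitch the results in rank order. This is a sketch under stated assumptions, not the library's algorithm. split_rows, patch_parallel_decode and upsample2x are hypothetical, the "decoder" is a purely local nearest-neighbour upsampler, and no torch.distributed communication is used. Real VAE decoders contain convolutions whose receptive field crosses band borders, so practical tiled decoders typically decode overlapping tiles and blend the seams; this sketch omits that.

```python
from typing import Callable

Grid = list[list[float]]


def split_rows(z: Grid, world_size: int) -> list[Grid]:
    """Split a 2-D latent into `world_size` contiguous row bands (sizes differ by at most 1)."""
    base, extra = divmod(len(z), world_size)
    bands, start = [], 0
    for rank in range(world_size):
        stop = start + base + (1 if rank < extra else 0)
        bands.append(z[start:stop])
        start = stop
    return bands


def patch_parallel_decode(z: Grid, world_size: int, decode_band: Callable[[Grid], Grid]) -> Grid:
    # In a distributed setting each rank would decode only its own band and the
    # results would be gathered; here every "rank" runs in one process.
    decoded = [decode_band(band) for band in split_rows(z, world_size)]
    # Stitch the decoded bands back together in rank order.
    return [row for band in decoded for row in band]


def upsample2x(band: Grid) -> Grid:
    # Toy "decoder": nearest-neighbour 2x upsampling, which is purely local,
    # so decoding bands independently gives the same result as decoding at once.
    out: Grid = []
    for row in band:
        wide = [v for v in row for _ in range(2)]
        out.extend([wide, list(wide)])
    return out


z = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
full = upsample2x(z)
parallel = patch_parallel_decode(z, world_size=2, decode_band=upsample2x)
print(parallel == full)  # True
```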

maybe_wrap_vae_decode_with_patch_parallelism

maybe_wrap_vae_decode_with_patch_parallelism(
    pipeline: Any,
    *,
    vae_patch_parallel_size: int,
    group_getter: Callable[[], ProcessGroup],
) -> None

Wrap a diffusers-style pipeline's vae.decode with patch/tile parallel decode.
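The "maybe" in the function name suggests the wrap is conditional. The sketch below shows one plausible shape for such a helper, mirroring the keyword-only signature above. It is hypothetical throughout: maybe_wrap_sketch and get_group are illustrative names, and the no-op conditions (no vae attribute, or a size of 1 or less), the _patch_parallel_wrapped marker that prevents double wrapping, and resolving the group through group_getter at call time are all assumptions rather than documented behaviour.

```python
from types import SimpleNamespace
from typing import Any, Callable


def maybe_wrap_sketch(
    pipeline: Any,
    *,
    vae_patch_parallel_size: int,
    group_getter: Callable[[], Any],
) -> None:
    """Hypothetical sketch of conditional, idempotent wrapping of `pipeline.vae.decode`."""
    vae = getattr(pipeline, "vae", None)
    # Assumption: nothing to do without a VAE or with a single-rank group.
    if vae is None or vae_patch_parallel_size <= 1:
        return
    # Assumption: a marker attribute prevents wrapping the same instance twice.
    if getattr(vae, "_patch_parallel_wrapped", False):
        return
    original = vae.decode

    def decode(z: Any, return_dict: bool = True, *args: Any, **kwargs: Any) -> Any:
        # Resolve the process group at call time rather than at wrap time;
        # a real implementation would split `z` into patches across it.
        group_getter()
        return original(z, return_dict, *args, **kwargs)

    vae.decode = decode
    vae._patch_parallel_wrapped = True


lookups: list[str] = []


def get_group() -> str:
    lookups.append("group")  # record when the group is actually requested
    return "group"


pipe = SimpleNamespace(vae=SimpleNamespace(decode=lambda z, return_dict=True: {"sample": z}))
maybe_wrap_sketch(pipe, vae_patch_parallel_size=2, group_getter=get_group)
maybe_wrap_sketch(pipe, vae_patch_parallel_size=2, group_getter=lambda: "other")  # no-op: already wrapped
print(lookups)               # []
print(pipe.vae.decode([1]))  # {'sample': [1]}
print(lookups)               # ['group']
```

Passing a callable instead of a ready ProcessGroup lets the caller wrap the pipeline before the distributed groups exist, which is one reason a getter might be preferred here.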