Skip to content

vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor

logger module-attribute

logger = init_logger(__name__)

DistributedOperator dataclass

exec instance-attribute

exec: callable

merge instance-attribute

merge: callable

split instance-attribute

split: callable

DistributedVaeExecutor

Abstract utility class for distributed patch/tile-parallel VAE decoding.

group instance-attribute

group = get_dit_group()

parallel_size instance-attribute

parallel_size = 1

rank instance-attribute

rank = get_rank(group)

world_size instance-attribute

world_size = get_world_size(group)

broadcast_tensor

broadcast_tensor(tensor: Tensor)

execute

execute(
    z: Tensor,
    operator: DistributedOperator,
    broadcast_result: bool = True,
)

gather_tensors

gather_tensors(tensor: Tensor)

set_parallel_size

set_parallel_size(parallel_size: int)

DistributedVaeMixin

init_distributed

init_distributed()

is_distributed_enabled

is_distributed_enabled() -> bool

set_parallel_size

set_parallel_size(parallel_size: int) -> None

GridSpec dataclass

Specification of how a tensor is split into a grid: the grid shape and the tensor dimensions it is split along.

grid_shape instance-attribute

grid_shape: tuple[int, ...]

output_dtype class-attribute instance-attribute

output_dtype: dtype | None = None

split_dims instance-attribute

split_dims: tuple[int, ...]

tile_spec class-attribute instance-attribute

tile_spec: dict = field(default_factory=dict)

TileTask dataclass

grid_coord instance-attribute

grid_coord: tuple[int, ...]

tensor instance-attribute

tensor: Tensor | list[Tensor]

tile_id instance-attribute

tile_id: int

workload class-attribute instance-attribute

workload: int | float = 1