vllm_omni.diffusion.distributed.comm

RingComm

Ring communication utility for Ring Attention P2P communication.

rank instance-attribute

rank = get_rank(_process_group)

recv_rank instance-attribute

recv_rank = (rank - 1) % world_size

send_rank instance-attribute

send_rank = (rank + 1) % world_size

world_size instance-attribute

world_size = get_world_size(_process_group)
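The neighbor arithmetic in these attributes can be checked in isolation. The sketch below is plain Python and does not touch `torch.distributed`; `ring_neighbors` is a hypothetical helper name used only for illustration, and it takes the rank and world size as arguments instead of reading them from a process group.

```python
def ring_neighbors(rank: int, world_size: int) -> tuple[int, int]:
    """Return (send_rank, recv_rank) for one position in a ring of world_size ranks."""
    send_rank = (rank + 1) % world_size  # next rank; the last rank wraps around to 0
    recv_rank = (rank - 1) % world_size  # previous rank; Python's % keeps this non-negative
    return send_rank, recv_rank


# Print the neighbors of every rank in a 4-rank ring.
for r in range(4):
    print(r, ring_neighbors(r, 4))
```

Because Python's modulo always returns a non-negative result for a positive divisor, rank 0 receives from rank `world_size - 1` without any special casing.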

commit

commit()

send_recv

send_recv(
    to_send: Tensor, recv_tensor: Tensor | None = None
) -> Tensor

wait

wait()
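These methods are listed without descriptions. As a rough mental model only, and assuming that `send_recv` hands a tensor to `send_rank` while receiving one from `recv_rank`, a ring pass rotates every rank's block by one position per round. The single-process simulation below illustrates that rotation in plain Python; `ring_step`, `blocks`, and `seen` are hypothetical names, and the code does not call `RingComm` or any distributed backend.

```python
def ring_step(blocks: list[str]) -> list[str]:
    """One ring round: rank r sends its block to (r + 1) % n and receives from (r - 1) % n."""
    n = len(blocks)
    return [blocks[(r - 1) % n] for r in range(n)]


world_size = 4
blocks = [f"kv{r}" for r in range(world_size)]  # block currently held by each rank
seen = [{b} for b in blocks]                    # blocks each rank has held so far

# After world_size - 1 rounds, every rank has held every block exactly once.
for _ in range(world_size - 1):
    blocks = ring_step(blocks)
    for r, b in enumerate(blocks):
        seen[r].add(b)

print(blocks)
print([sorted(s) for s in seen])
```

This is why ring attention needs only `world_size - 1` exchange rounds: each rank starts with its own block and receives one new block per round.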

SeqAllToAll4D

Bases: Function

forward staticmethod

forward(
    ctx: Any,
    group: ProcessGroup,
    input: Tensor,
    scatter_idx: int,
    gather_idx: int,
    use_sync: bool = False,
) -> Tensor

SeqAllToAll5D

Bases: Function

forward staticmethod

forward(
    ctx: Any,
    group: ProcessGroup,
    input: Tensor,
    scatter_idx: int = 3,
    gather_idx: int = 1,
    use_sync: bool = False,
) -> Tensor

all_to_all_4D

all_to_all_4D(
    input: tensor,
    scatter_idx: int = 2,
    gather_idx: int = 1,
    group=None,
    use_sync: bool = False,
) -> tensor

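All-to-all for QKV tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input | tensor | Input tensor, sharded across the process group. | required |
| scatter_idx | int | Dimension to scatter (split) across ranks. | 2 |
| gather_idx | int | Dimension to gather (concatenate) across ranks. | 1 |
| group | ProcessGroup | Torch process group. | None |
| use_sync | bool | Whether to synchronize after the all-to-all. | False |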

Returns:

| Type | Description |
| --- | --- |
| tensor | Resharded tensor (bs, seqlen/P, hc, hs). |
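To make the data movement concrete, here is a single-process simulation, assuming a (bs, seqlen, hc, hs) layout in which the default `scatter_idx=2` splits the head dimension and the default `gather_idx=1` concatenates the sequence dimension. It uses nested Python lists instead of tensors and drops the batch and head-size dimensions, which pass through unchanged. `simulate_all_to_all` and `shards` are hypothetical names; this is not the library's implementation.

```python
def simulate_all_to_all(shards, world_size):
    """shards[r][s][h] is rank r's local block, indexed by local sequence position s and head h.

    Returns out[r][s][h], where rank r holds the full sequence but only its slice of heads.
    """
    heads = len(shards[0][0])
    assert heads % world_size == 0, "head count must divide evenly across ranks"
    per_rank = heads // world_size
    out = []
    for dst in range(world_size):
        lo, hi = dst * per_rank, (dst + 1) * per_rank
        gathered = []
        for src in range(world_size):        # concatenate along the sequence in rank order
            for row in shards[src]:
                gathered.append(row[lo:hi])  # keep only the heads assigned to dst
        out.append(gathered)
    return out


P, seq_per_rank, heads = 2, 2, 4
# Each element is tagged (global sequence position, head) so its origin stays visible.
shards = [
    [[(r * seq_per_rank + s, h) for h in range(heads)] for s in range(seq_per_rank)]
    for r in range(P)
]
out = simulate_all_to_all(shards, P)
print(len(out[0]), len(out[0][0]))  # sequence length and head count held by rank 0
```

Before the exchange each rank holds 2 sequence positions with all 4 heads; afterwards each rank holds all 4 sequence positions with 2 heads. Swapping the two indices would describe the reverse exchange.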

all_to_all_5D

all_to_all_5D(
    input: tensor,
    scatter_idx: int = 3,
    gather_idx: int = 1,
    group=None,
    use_sync: bool = False,
) -> tensor

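All-to-all for packed QKV. Forward: (bs, seqlen/N, 3, hc, hs) -> (bs, seqlen, 3, hc/N, hs).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input | tensor | Input tensor, sharded across the process group. | required |
| scatter_idx | int | Dimension to scatter (split) across ranks. | 3 |
| gather_idx | int | Dimension to gather (concatenate) across ranks. | 1 |
| group | ProcessGroup | Torch process group. | None |
| use_sync | bool | Whether to synchronize after the all-to-all. | False |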

Returns:

| Type | Description |
| --- | --- |
| tensor | Resharded tensor; (bs, seqlen, 3, hc/N, hs) for the forward direction with the default indices. |
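The shape bookkeeping for these functions can be sketched as simple arithmetic, assuming the usual all-to-all semantics: the scatter dimension shrinks by the group size and the gather dimension grows by it. `resharded_shape` is a hypothetical helper written for this page, not part of the module, and the real functions may accept only specific index pairs.

```python
def resharded_shape(shape, scatter_idx, gather_idx, world_size):
    """Return the per-rank shape after an all-to-all, under the assumption stated above."""
    if shape[scatter_idx] % world_size != 0:
        raise ValueError("scatter dimension must be divisible by world_size")
    out = list(shape)
    out[scatter_idx] //= world_size  # each rank keeps 1/world_size of this dimension
    out[gather_idx] *= world_size    # pieces from all ranks are concatenated here
    return tuple(out)


N = 4
local = (2, 1024 // N, 3, 16, 64)  # (bs, seqlen/N, 3, hc, hs) on one rank
print(resharded_shape(local, scatter_idx=3, gather_idx=1, world_size=N))
# (2, 1024, 3, 4, 64), i.e. (bs, seqlen, 3, hc/N, hs)
```

With the 5D defaults (`scatter_idx=3`, `gather_idx=1`) this reproduces the forward mapping (bs, seqlen/N, 3, hc, hs) -> (bs, seqlen, 3, hc/N, hs).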