vllm_omni.diffusion.distributed.comm ¶
RingComm ¶
Ring communication utility for Ring Attention P2P communication.
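The ring pattern behind Ring Attention can be illustrated with a single-process model. This is only a sketch of the general technique, not of the `RingComm` API: `ring_pass` and the `kv{r}` labels are hypothetical names introduced for illustration, and no real P2P communication takes place.

```python
def ring_pass(blocks):
    """Hypothetical model of one ring step: rank r receives the block of rank r-1."""
    world_size = len(blocks)
    return [blocks[(r - 1) % world_size] for r in range(world_size)]


world_size = 4
held = [f"kv{r}" for r in range(world_size)]   # block currently held by each rank
seen = [{held[r]} for r in range(world_size)]  # blocks each rank has processed so far

# After world_size - 1 steps every rank has seen every block exactly once.
for _ in range(world_size - 1):
    held = ring_pass(held)
    for r in range(world_size):
        seen[r].add(held[r])
```

In this model each rank only ever exchanges data with its two neighbours, which is why the pattern maps onto point-to-point send and receive calls.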
SeqAllToAll4D ¶
SeqAllToAll5D ¶
all_to_all_4D ¶
all_to_all_4D(
    input: tensor,
    scatter_idx: int = 2,
    gather_idx: int = 1,
    group=None,
    use_sync: bool = False,
) -> tensor
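All-to-all for 4D QKV tensors: scatters the input along `scatter_idx` and gathers it along `gather_idx` across the ranks of `group`.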
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
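| input | tensor | input tensor, sharded across the ranks of `group` along dimension `gather_idx` | required |
| scatter_idx | int | index of the dimension to scatter (split) across ranks | 2 |
| gather_idx | int | index of the dimension to gather (concatenate) across ranks | 1 |
| group | ProcessGroup | torch process group used for the all-to-all | None |
| use_sync | bool | whether to synchronize after the all-to-all | False |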
Returns:
| Type | Description |
|---|---|
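| tensor | resharded tensor. With the default indices, an input of shape (bs, seqlen/P, hc, hs) becomes (bs, seqlen, hc/P, hs); with `scatter_idx=1` and `gather_idx=2`, the result is (bs, seqlen/P, hc, hs) |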
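The shape bookkeeping can be checked without any distributed setup. The helper below is a hypothetical sketch, not part of `vllm_omni`; it assumes the usual all-to-all convention that the scattered dimension shrinks by the number of ranks P while the gathered dimension grows by P.

```python
def resharded_shape(shape, scatter_idx, gather_idx, world_size):
    """Hypothetical helper: per-rank shape after an all-to-all over `world_size` ranks."""
    if shape[scatter_idx] % world_size != 0:
        raise ValueError("scatter dimension must be divisible by world_size")
    out = list(shape)
    out[scatter_idx] //= world_size  # each rank keeps 1/P of the scattered dim
    out[gather_idx] *= world_size    # and collects P shards of the gathered dim
    return tuple(out)


# Example values (assumed): bs=2, seqlen=1024 split over P=4 ranks, hc=16, hs=64.
local_in = (2, 1024 // 4, 16, 64)
local_out = resharded_shape(local_in, scatter_idx=2, gather_idx=1, world_size=4)
# Swapping the two indices undoes the resharding.
back = resharded_shape(local_out, scatter_idx=1, gather_idx=2, world_size=4)
```

Under these assumptions the head count `hc` must be divisible by P for the default direction, and the sequence length must be divisible by P for the swapped direction.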
all_to_all_5D ¶
all_to_all_5D(
    input: tensor,
    scatter_idx: int = 3,
    gather_idx: int = 1,
    group=None,
    use_sync: bool = False,
) -> tensor
All-to-all for packed 5D QKV tensors. In the forward direction it reshards (bs, seqlen/N, 3, hc, hs) to (bs, seqlen, 3, hc/N, hs), where N is the number of ranks in `group`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
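| input | tensor | input tensor, sharded across the ranks of `group` along dimension `gather_idx` | required |
| scatter_idx | int | index of the dimension to scatter (split) across ranks | 3 |
| gather_idx | int | index of the dimension to gather (concatenate) across ranks | 1 |
| group | ProcessGroup | torch process group used for the all-to-all | None |
| use_sync | bool | whether to synchronize after the all-to-all | False |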
Returns:
| Type | Description |
|---|---|
| tensor | resharded tensor: (bs, seqlen, 3, hc/N, hs) in the forward direction (the default indices), or (bs, seqlen/N, 3, hc, hs) in the reverse direction |
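To see which data ends up where, the forward and reverse directions can be modelled in a single process. This is a sketch under stated assumptions, not the library's implementation: `simulate_all_to_all` is a hypothetical function, each rank's shard is a dict keyed by (sequence position, head), and the packed QKV dimension (3) and head size (hs) are left out because they stay local to each rank.

```python
def simulate_all_to_all(per_rank, key_index, chunk):
    """Hypothetical single-process model of an all-to-all.

    per_rank[r] maps (seq_pos, head) -> value held by rank r. Every entry is
    delivered to the rank that owns its chunk along `key_index`
    (0 = sequence position, 1 = head).
    """
    world_size = len(per_rank)
    received = [dict() for _ in range(world_size)]
    for entries in per_rank:
        for key, value in entries.items():
            dest = key[key_index] // chunk
            received[dest][key] = value
    return received


P, seqlen, hc = 2, 4, 4
# Before: rank r holds sequence chunk r with all heads, i.e. (seqlen/P, hc).
before = [
    {
        (s, h): f"qkv[{s},{h}]"
        for s in range(r * seqlen // P, (r + 1) * seqlen // P)
        for h in range(hc)
    }
    for r in range(P)
]
# Forward: scatter heads, gather sequence, giving (seqlen, hc/P) per rank.
after = simulate_all_to_all(before, key_index=1, chunk=hc // P)
# Reverse: scatter sequence, gather heads, restoring the original layout.
restored = simulate_all_to_all(after, key_index=0, chunk=seqlen // P)
```

After the forward step every rank holds the full sequence for its own subset of heads, which is the layout attention needs; the reverse step returns each rank to its sequence shard.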