vllm_omni.distributed.omni_connectors.utils.kv_utils ¶
Utility helpers for KV cache manipulation, TP routing, and merge/slice.
KVTPTopology dataclass ¶
Immutable descriptor for a KV-transfer parallel mapping.
Captures sender/receiver parallel sizes and the local rank within that parallel dimension. Works for any divisible parallel dimension (TP, SP, Ring Attention).
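As a rough illustration of what an immutable descriptor of this kind could look like, here is a minimal sketch using a frozen dataclass. The class name `TopologySketch` and the field names `from_tp`, `to_tp`, and `local_rank` are assumptions made for this example; the actual fields of `KVTPTopology` are not listed on this page.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class TopologySketch:
    """Hypothetical stand-in for KVTPTopology (field names are assumed)."""

    from_tp: int     # sender-side parallel size
    to_tp: int       # receiver-side parallel size
    local_rank: int  # this worker's rank within the parallel dimension


topo = TopologySketch(from_tp=4, to_tp=2, local_rank=1)

# Frozen dataclasses reject attribute assignment after construction.
try:
    topo.local_rank = 0
except FrozenInstanceError:
    mutated = False
else:
    mutated = True

# A modified copy is created with dataclasses.replace instead of mutation.
other_rank = replace(topo, local_rank=0)
```

Because the instance is frozen and compares by value, it is also hashable, so a descriptor like this can be used as a dictionary key or cached safely.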
build_rank_aware_recv_keys ¶
build_rank_aware_recv_keys(
request_id: str,
from_stage: str,
to_stage: str,
topo: KVTPTopology,
hook: Callable | None = None,
) -> list[tuple[str, int | None]]
Build recv-side connector keys with sender rank info.
Returns a list of (key, from_rank) tuples. from_rank is None when TP <= 1 (single sender, no per-rank routing needed). For TP > 1, from_rank identifies which sender rank owns the key so that the connector can route metadata queries to the correct endpoint.
build_rank_aware_send_keys ¶
build_rank_aware_send_keys(
request_id: str,
from_stage: str,
to_stage: str,
topo: KVTPTopology,
hook: Callable | None = None,
) -> list[str]
Build send-side connector keys, checking the injectable hook first.
get_kv_connector_key ¶
get_kv_connector_key(
req_id: str,
from_stage: int | str,
chunk_id: int,
from_rank: int,
to_rank: int,
) -> str
Build a connector key that includes rank info for KV transfers.
Format matches PR #2677: {req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}
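The key format above can be reproduced with a plain f-string. The helper below is a sketch written for illustration, not the library function itself; the name `make_kv_key` and the sample values are hypothetical.

```python
from __future__ import annotations


def make_kv_key(
    req_id: str,
    from_stage: int | str,
    chunk_id: int,
    from_rank: int,
    to_rank: int,
) -> str:
    """Join the five fields with underscores, in the documented order."""
    return f"{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}"


key = make_kv_key("req-42", 0, 3, 1, 0)

# Request IDs may themselves contain underscores, so splitting from the
# right is the safer way to recover the trailing fields.
parts = make_kv_key("user_7_req", 0, 3, 1, 0).rsplit("_", 4)
```

Splitting from the right only works cleanly when `from_stage` contains no underscore, which is worth keeping in mind if stage names are strings.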
get_kv_source_ranks ¶
get_kv_source_ranks(topo: KVTPTopology) -> list[int]
Which remote ranks this local rank receives KV shards from (recv side).
get_kv_target_ranks ¶
get_kv_target_ranks(topo: KVTPTopology) -> list[int]
Which remote ranks this local rank sends KV shards to (send side).
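To make the send-side and recv-side views concrete, the sketch below shows one plausible mapping for divisible TP sizes: ranks on the larger side are grouped into contiguous blocks, and each block pairs with one rank on the smaller side. The contiguous-block rule, the function names, and the plain-integer arguments are all assumptions; the mapping actually used by `get_kv_source_ranks` and `get_kv_target_ranks` is not spelled out on this page.

```python
def sketch_source_ranks(from_tp: int, to_tp: int, local_rank: int) -> list[int]:
    """Sender ranks that receiver `local_rank` would read from (assumed rule)."""
    if from_tp >= to_tp:
        # More senders than receivers: each receiver gathers a block of senders.
        ratio = from_tp // to_tp
        return list(range(local_rank * ratio, (local_rank + 1) * ratio))
    # Fewer senders than receivers: several receivers share one sender.
    ratio = to_tp // from_tp
    return [local_rank // ratio]


def sketch_target_ranks(from_tp: int, to_tp: int, local_rank: int) -> list[int]:
    """Receiver ranks that sender `local_rank` would write to (assumed rule)."""
    if to_tp >= from_tp:
        # More receivers than senders: each sender fans out to a block.
        ratio = to_tp // from_tp
        return list(range(local_rank * ratio, (local_rank + 1) * ratio))
    # Fewer receivers than senders: several senders feed one receiver.
    ratio = from_tp // to_tp
    return [local_rank // ratio]


# TP 4 -> TP 2: receiver 1 gathers from senders 2 and 3.
gather = sketch_source_ranks(from_tp=4, to_tp=2, local_rank=1)
# TP 2 -> TP 4: sender 0 fans out to receivers 0 and 1.
fan_out = sketch_target_ranks(from_tp=2, to_tp=4, local_rank=0)
```

Under this rule the two functions are mirror images: sender `s` lists receiver `r` as a target exactly when receiver `r` lists sender `s` as a source.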
get_local_tp_rank ¶
get_local_tp_rank() -> int
Return the TP-local rank of this worker process.
Uses get_tensor_model_parallel_rank(), which returns the rank within the TP group only, not the stage-global rank.
get_omni_replica_id ¶
get_omni_replica_id() -> int
Return the Omni replica id for this worker process.
get_tp_world_size ¶
get_tp_world_size() -> int
Return the TP world size (tensor-parallel dimension only).
Uses get_tensor_model_parallel_world_size() so that cfg_parallel, SP, PP etc. are not included in the count.
kv_zmq_port ¶
kv_zmq_port(
base_port: int,
from_stage: int,
local_rank: int = 0,
replica_id: int | None = None,
) -> int
Compute the ZMQ port for a KV-transfer connector.
Each Omni replica and TP rank gets its own port so multi-replica or TP > 1 deployments do not cause EADDRINUSE when multiple sender workers bind on the same host. The formula is backward-compatible: replica 0 / rank 0 produces the previous base + OFFSET + stage port.
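The exact formula is not reproduced on this page, so the following is only a sketch of one scheme with the stated property: replica 0 and rank 0 reduce to `base + OFFSET + stage`, and every other (replica, rank) pair lands on a different port. The constant values, the stride layout, the function name, and the choice to treat `replica_id=None` as replica 0 are all assumptions.

```python
from __future__ import annotations

# Hypothetical constants; the real values are not given in this reference.
OFFSET = 100
RANK_STRIDE = 16      # leaves room for stages 0..15 per rank
REPLICA_STRIDE = 256  # leaves room for ranks 0..15 per replica


def sketch_kv_port(
    base_port: int,
    from_stage: int,
    local_rank: int = 0,
    replica_id: int | None = None,
) -> int:
    """Derive a distinct port per (replica, rank, stage) triple."""
    replica = 0 if replica_id is None else replica_id  # assumption: None -> 0
    return (
        base_port
        + OFFSET
        + from_stage
        + local_rank * RANK_STRIDE
        + replica * REPLICA_STRIDE
    )


legacy_port = sketch_kv_port(50000, from_stage=2)
rank1_port = sketch_kv_port(50000, from_stage=2, local_rank=1)
```

With strides like these, uniqueness holds only while the stage index stays below `RANK_STRIDE` and the rank count stays below `REPLICA_STRIDE // RANK_STRIDE`; a real implementation would need bounds that match the deployment.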
merge_received_rank_shards ¶
merge_received_rank_shards(
payloads: list[dict[str, Any]],
merger: Callable | None = None,
) -> dict[str, Any] | None
Merge multiple source-rank KV shards for one target rank.
When merger is provided (injectable hook), it is called directly. Otherwise the default merges along the head dimension (dim 1).
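One natural reading of "merging along the head dimension" is concatenation along dim 1. The sketch below shows that operation on nested Python lists so it runs without extra dependencies. The `[tokens, heads, head_dim]` layout, the function name, and the sample values are assumptions for illustration. With PyTorch tensors the equivalent call would be `torch.cat(shards, dim=1)`.

```python
def concat_heads(shards):
    """Concatenate nested lists shaped [tokens][heads][head_dim] along dim 1."""
    if not shards:
        return None
    n_tokens = len(shards[0])
    if any(len(shard) != n_tokens for shard in shards):
        raise ValueError("all shards must cover the same number of tokens")
    # For each token, append the heads of every shard in source-rank order.
    return [
        [head for shard in shards for head in shard[t]]
        for t in range(n_tokens)
    ]


# Two source ranks, each holding 2 heads for 2 tokens (head_dim = 1).
rank0 = [[[0.0], [0.1]], [[1.0], [1.1]]]
rank1 = [[[0.2], [0.3]], [[1.2], [1.3]]]
merged = concat_heads([rank0, rank1])
```

The order of the input list matters: heads end up in source-rank order, so shards should be sorted by sender rank before merging.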
normalize_layer_kv ¶
normalize_layer_kv(
layer_kv: LayerKV,
*,
req_id: str = "",
layer_idx: int = -1,
) -> tuple[Tensor, Tensor] | None
Normalize one layer KV cache to a (key_blocks, value_blocks) tuple.
In vLLM, different attention backends return paged-attention KV blocks with different layouts. For example:
- FlashAttention (vllm/v1/attention/backends/flash_attn.py) returns shape (2, num_blocks, block_size, num_kv_heads, head_size) – the key/value dimension is at dim 0.
- FlashInfer (vllm/v1/attention/backends/flashinfer.py) returns shape (num_blocks, 2, block_size, num_kv_heads, head_size) – the key/value dimension is at dim 1.
This utility handles both layouts (and the tuple case) so that downstream code can work with any backend.
Supported layouts:
- Stacked tensor [2, num_blocks, block_size, n_heads, head_dim] – dim-0 selects key / value.
- Stacked tensor [num_blocks, 2, block_size, n_heads, head_dim] – dim-1 selects key / value.
- Tuple (key_tensor, value_tensor) – returned as-is after validation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| layer_kv | LayerKV | The raw KV cache (tensor or tuple) for the layer. | required |
| req_id | str | Request ID used only for diagnostic log messages. | '' |
| layer_idx | int | Layer index used only for diagnostic log messages. | -1 |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] \| None | A (key_blocks, value_blocks) tuple when the layout is recognized; None otherwise. |
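The layout detection for stacked tensors can be sketched on shapes alone: look for the size-2 key/value axis at dim 0 or dim 1 of a 5-D shape. The function name is hypothetical, and the tie-break (preferring dim 0 when both leading dims equal 2) is an assumption; how the real function resolves that ambiguity is not stated here.

```python
from __future__ import annotations


def find_kv_axis(shape: tuple[int, ...]) -> int | None:
    """Return the axis that selects key/value, or None if unrecognized."""
    if len(shape) != 5:
        return None
    if shape[0] == 2:
        return 0  # FlashAttention-style: (2, num_blocks, ...)
    if shape[1] == 2:
        return 1  # FlashInfer-style: (num_blocks, 2, ...)
    return None


flash_attn_axis = find_kv_axis((2, 128, 16, 8, 64))
flashinfer_axis = find_kv_axis((128, 2, 16, 8, 64))
unknown_axis = find_kv_axis((128, 4, 16, 8, 64))
```

Once the axis is known, indexing positions 0 and 1 along it yields the key and value blocks; with a PyTorch tensor, `tensor.unbind(axis)` would give the pair directly.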
slice_kv_tensor_heads ¶
slice_kv_tensor_heads(
tensor: Tensor | None,
offset_in_shard: int,
num_slices: int,
) -> Tensor | None
Slice one KV tensor along its head dimension (dim 1).
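A sketch of head slicing, under the assumption that the head dimension is split into `num_slices` equal parts and `offset_in_shard` picks one of them. That interpretation of the two parameters is a guess based on their names. Nested lists shaped `[tokens][heads][head_dim]` stand in for tensors so the example has no dependencies.

```python
def sketch_slice_heads(tensor, offset_in_shard, num_slices):
    """Keep one of `num_slices` equal head groups along dim 1."""
    if tensor is None:
        return None  # mirrors the Tensor | None signature
    if not tensor:
        return []
    n_heads = len(tensor[0])
    if num_slices < 1 or n_heads % num_slices != 0:
        raise ValueError(f"{n_heads} heads cannot be split into {num_slices} slices")
    if not 0 <= offset_in_shard < num_slices:
        raise ValueError(f"offset {offset_in_shard} out of range for {num_slices} slices")
    per_slice = n_heads // num_slices
    start = offset_in_shard * per_slice
    # Apply the same head range to every token row.
    return [row[start:start + per_slice] for row in tensor]


# One token with four heads; the second of two slices keeps heads 2 and 3.
full = [[["h0"], ["h1"], ["h2"], ["h3"]]]
second_half = sketch_slice_heads(full, offset_in_shard=1, num_slices=2)
```

On a PyTorch tensor the same selection would be `tensor[:, start:start + per_slice]`.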
slice_layer_blocks ¶
slice_layer_blocks(
layer_blocks: dict[str, Any],
offset_in_shard: int,
num_slices: int,
) -> dict[str, list[Tensor | None]]
Slice all KV layers for one logical receiver rank.
slice_received_rank_shard ¶
slice_received_rank_shard(
payload: dict[str, Any] | None,
topo: KVTPTopology,
slicer: Callable | None = None,
) -> dict[str, Any] | None
Optionally slice a received payload to extract this rank's portion.
Used when to_tp > from_tp: the sender sent full heads and each receiver rank slices out its own subset.
validate_kv_tp_topology ¶
validate_kv_tp_topology(topo: KVTPTopology) -> None
Reject heterogeneous TP mappings that cannot be routed losslessly.
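The precise rejection criteria are not listed here. Given that KVTPTopology is described as working for divisible parallel dimensions, one plausible check is that the larger TP size is a whole multiple of the smaller one. The sketch below implements only that assumed rule; the function name, the plain-integer arguments, and the choice of `ValueError` are hypothetical.

```python
def check_tp_mapping(from_tp: int, to_tp: int) -> None:
    """Raise ValueError unless one TP size evenly divides the other."""
    if from_tp < 1 or to_tp < 1:
        raise ValueError(f"TP sizes must be positive, got {from_tp} -> {to_tp}")
    larger, smaller = max(from_tp, to_tp), min(from_tp, to_tp)
    if larger % smaller != 0:
        # Heads could not be split or merged into whole per-rank groups.
        raise ValueError(f"TP {from_tp} -> {to_tp} is not evenly divisible")


check_tp_mapping(4, 2)  # accepted: two sender shards per receiver
check_tp_mapping(2, 4)  # accepted: each sender shard is split in two

try:
    check_tp_mapping(3, 2)
except ValueError as exc:
    rejected = str(exc)
else:
    rejected = None
```

Validating once, before any keys are built or ports are bound, keeps the routing helpers free of repeated divisibility checks.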