vllm_omni.distributed.omni_connectors.utils.kv_utils

Utility helpers for KV cache manipulation, TP routing, and merge/slice.

LayerKV module-attribute

LayerKV = Tensor | tuple[Tensor, Tensor]

logger module-attribute

logger = init_logger(__name__)

KVTPTopology dataclass

Immutable descriptor for a KV-transfer parallel mapping.

Captures the sender and receiver parallel sizes and the local rank within that parallel dimension. Works for any parallel dimension in which one size evenly divides the other (TP, SP, Ring Attention).

is_heterogeneous property

is_heterogeneous: bool

local_rank instance-attribute

local_rank: int

ratio property

ratio: int

Larger parallel size divided by smaller. Always >= 1.

source_tp_size instance-attribute

source_tp_size: int

target_tp_size instance-attribute

target_tp_size: int
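
To make the two derived properties concrete, here is a minimal sketch of how such a descriptor could be written. `TopologySketch` is a hypothetical stand-in, not the library's source; in particular, treating `is_heterogeneous` as "the two sizes differ" is an assumption, while `ratio` follows the description above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen=True makes instances immutable
class TopologySketch:
    """Hypothetical stand-in for KVTPTopology; not the library's source."""

    source_tp_size: int
    target_tp_size: int
    local_rank: int

    @property
    def is_heterogeneous(self) -> bool:
        # Assumed meaning: sender and receiver parallel sizes differ.
        return self.source_tp_size != self.target_tp_size

    @property
    def ratio(self) -> int:
        # Larger parallel size divided by the smaller one, so always >= 1.
        larger = max(self.source_tp_size, self.target_tp_size)
        smaller = min(self.source_tp_size, self.target_tp_size)
        return larger // smaller


topo = TopologySketch(source_tp_size=4, target_tp_size=2, local_rank=1)
print(topo.is_heterogeneous, topo.ratio)  # True 2
```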

build_rank_aware_recv_keys

build_rank_aware_recv_keys(
    request_id: str,
    from_stage: str,
    to_stage: str,
    topo: KVTPTopology,
    hook: Callable | None = None,
) -> list[tuple[str, int | None]]

Build recv-side connector keys with sender rank info.

Returns a list of (key, from_rank) tuples. from_rank is None when TP <= 1 (single sender, no per-rank routing needed). For TP > 1, from_rank identifies which sender rank owns the key so that the connector can route metadata queries to the correct endpoint.

build_rank_aware_send_keys

build_rank_aware_send_keys(
    request_id: str,
    from_stage: str,
    to_stage: str,
    topo: KVTPTopology,
    hook: Callable | None = None,
) -> list[str]

Build send-side connector keys, checking the injectable hook first.

get_kv_connector_key

get_kv_connector_key(
    req_id: str,
    from_stage: int | str,
    chunk_id: int,
    from_rank: int,
    to_rank: int,
) -> str

Build connector key that includes rank info for KV transfers.

Format matches PR #2677: {req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}

get_kv_source_ranks

get_kv_source_ranks(topo: KVTPTopology) -> list[int]

Which remote ranks this local rank receives KV shards from (recv side).

get_kv_target_ranks

get_kv_target_ranks(topo: KVTPTopology) -> list[int]

Which remote ranks this local rank sends KV shards to (send side).
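
The page does not spell out the routing rule behind `get_kv_source_ranks` and `get_kv_target_ranks`. The sketch below assumes a contiguous block mapping, which is a common choice when one parallel size evenly divides the other: the side with more ranks is grouped into runs of `ratio` consecutive ranks, each paired with one rank on the smaller side. The function names and the mapping itself are assumptions for illustration.

```python
def source_ranks_sketch(source_tp: int, target_tp: int, local_rank: int) -> list[int]:
    """Sender ranks a receiver `local_rank` reads from (assumed mapping)."""
    if source_tp >= target_tp:
        # More senders than receivers: each receiver gathers `ratio` shards.
        ratio = source_tp // target_tp
        return [local_rank * ratio + i for i in range(ratio)]
    # Fewer senders than receivers: several receivers share one sender.
    ratio = target_tp // source_tp
    return [local_rank // ratio]


def target_ranks_sketch(source_tp: int, target_tp: int, local_rank: int) -> list[int]:
    """Receiver ranks a sender `local_rank` writes to (assumed mapping)."""
    if target_tp >= source_tp:
        # More receivers than senders: each sender fans out to `ratio` ranks.
        ratio = target_tp // source_tp
        return [local_rank * ratio + i for i in range(ratio)]
    # Fewer receivers than senders: several senders feed one receiver.
    ratio = source_tp // target_tp
    return [local_rank // ratio]


print(source_ranks_sketch(4, 2, 1))  # [2, 3]
print(target_ranks_sketch(4, 2, 3))  # [1]
```

Under this assumed mapping the two functions are mirror images: sender 3 targets receiver 1, and receiver 1 lists sender 3 among its sources.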

get_local_tp_rank

get_local_tp_rank() -> int

Return the TP-local rank of this worker process.

Uses get_tensor_model_parallel_rank(), which returns the rank within the TP group only, not the stage-global rank.

get_omni_replica_id

get_omni_replica_id() -> int

Return the Omni replica id for this worker process.

get_tp_world_size

get_tp_world_size() -> int

Return the TP world size (tensor-parallel dimension only).

Uses get_tensor_model_parallel_world_size() so that cfg_parallel, SP, PP etc. are not included in the count.

kv_zmq_port

kv_zmq_port(
    base_port: int,
    from_stage: int,
    local_rank: int = 0,
    replica_id: int | None = None,
) -> int

Compute the ZMQ port for a KV-transfer connector.

Each Omni replica and TP rank gets its own port, so multi-replica or TP > 1 deployments do not hit EADDRINUSE when multiple sender workers bind on the same host. The formula is backward-compatible: replica 0 / rank 0 yields the same port as the previous scheme, base + OFFSET + stage.

merge_received_rank_shards

merge_received_rank_shards(
    payloads: list[dict[str, Any]],
    merger: Callable | None = None,
) -> dict[str, Any] | None

Merge multiple source-rank KV shards for one target rank.

When merger is provided (injectable hook), it is called directly. Otherwise the default merges along the head dimension (dim 1).
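
At the tensor level, merging along the head dimension amounts to concatenating the per-rank shards on dim 1. The sketch below shows only that step with `torch.cat`. The payload dictionary layout is not documented here, so the example works on bare tensors, and the `(tokens, heads, head_dim)` shape is an assumption.

```python
import torch


def merge_heads_sketch(shards: list[torch.Tensor]) -> torch.Tensor:
    # Concatenate shards from consecutive source ranks along dim 1 (heads).
    return torch.cat(shards, dim=1)


# Two source ranks, each holding 2 of the 4 KV heads (assumed shapes).
rank0 = torch.zeros(8, 2, 64)
rank1 = torch.ones(8, 2, 64)
merged = merge_heads_sketch([rank0, rank1])
print(merged.shape)  # torch.Size([8, 4, 64])
```

Shard order matters: the list has to follow source-rank order for heads to land in their original positions.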

normalize_layer_kv

normalize_layer_kv(
    layer_kv: LayerKV,
    *,
    req_id: str = "",
    layer_idx: int = -1,
) -> tuple[Tensor, Tensor] | None

Normalize one layer KV cache to a (key_blocks, value_blocks) tuple.

In vLLM, different attention backends return paged-attention KV blocks with different layouts. For example:

  • FlashAttention (vllm/v1/attention/backends/flash_attn.py) returns shape (2, num_blocks, block_size, num_kv_heads, head_size) – the key/value dimension is at dim 0.
  • FlashInfer (vllm/v1/attention/backends/flashinfer.py) returns shape (num_blocks, 2, block_size, num_kv_heads, head_size) – the key/value dimension is at dim 1.

This utility handles both layouts (and the tuple case) so that downstream code can work with any backend.

Supported layouts:

  • Stacked tensor [2, num_blocks, block_size, n_heads, head_dim] – dim-0 selects key / value.
  • Stacked tensor [num_blocks, 2, block_size, n_heads, head_dim] – dim-1 selects key / value.
  • Tuple (key_tensor, value_tensor) – returned as-is after validation.

Parameters:

  • layer_kv (LayerKV, required) – The raw KV cache (tensor or tuple) for the layer.
  • req_id (str, default '') – Request ID used only for diagnostic log messages.
  • layer_idx (int, default -1) – Layer index used only for diagnostic log messages.

Returns:

  • tuple[Tensor, Tensor] | None – (key_blocks, value_blocks) if layer_kv is valid, None otherwise.
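
The three supported layouts can be told apart by type and by which leading dimension has size 2. The following is a minimal sketch of that dispatch, not the library's implementation: `normalize_sketch` is a hypothetical name, logging is omitted, and checking dim 0 before dim 1 when both have size 2 is an assumption.

```python
import torch


def normalize_sketch(layer_kv):
    """Return (key_blocks, value_blocks), or None if the layout is unknown."""
    if isinstance(layer_kv, tuple):
        # Tuple case: pass through after a basic validity check.
        if len(layer_kv) == 2 and all(isinstance(t, torch.Tensor) for t in layer_kv):
            return layer_kv[0], layer_kv[1]
        return None
    if not isinstance(layer_kv, torch.Tensor) or layer_kv.dim() < 2:
        return None
    if layer_kv.shape[0] == 2:
        # FlashAttention-style layout: key/value on dim 0.
        return layer_kv[0], layer_kv[1]
    if layer_kv.shape[1] == 2:
        # FlashInfer-style layout: key/value on dim 1.
        return layer_kv[:, 0], layer_kv[:, 1]
    return None


flash_attn_style = torch.randn(2, 4, 16, 8, 64)
flashinfer_style = torch.randn(4, 2, 16, 8, 64)
k0, v0 = normalize_sketch(flash_attn_style)
k1, v1 = normalize_sketch(flashinfer_style)
print(k0.shape, k1.shape)  # both torch.Size([4, 16, 8, 64])
```

Either layout yields key and value blocks of shape (num_blocks, block_size, num_kv_heads, head_size), which is what lets downstream code ignore the backend.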

slice_kv_tensor_heads

slice_kv_tensor_heads(
    tensor: Tensor | None,
    offset_in_shard: int,
    num_slices: int,
) -> Tensor | None

Slice one KV tensor along its head dimension (dim 1).

slice_layer_blocks

slice_layer_blocks(
    layer_blocks: dict[str, Any],
    offset_in_shard: int,
    num_slices: int,
) -> dict[str, list[Tensor | None]]

Slice all KV layers for one logical receiver rank.

slice_received_rank_shard

slice_received_rank_shard(
    payload: dict[str, Any] | None,
    topo: KVTPTopology,
    slicer: Callable | None = None,
) -> dict[str, Any] | None

Optionally slice a received payload to extract this rank's portion.

Used when the receiver TP size exceeds the sender TP size (to_tp > from_tp): the sender transmits its full set of heads, and each receiver rank slices out its own subset.

validate_kv_tp_topology

validate_kv_tp_topology(topo: KVTPTopology) -> None

Reject heterogeneous TP mappings that cannot be routed losslessly.
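
The rejection criteria are not listed on this page. Given that KVTPTopology is described as working for divisible parallel dimensions, a reasonable guess is that the check centers on divisibility. The sketch below encodes only that guess; the function name, the exception type, and the messages are assumptions.

```python
def validate_topology_sketch(source_tp_size: int, target_tp_size: int) -> None:
    if source_tp_size < 1 or target_tp_size < 1:
        raise ValueError("parallel sizes must be positive")
    larger = max(source_tp_size, target_tp_size)
    smaller = min(source_tp_size, target_tp_size)
    if larger % smaller != 0:
        # Heads could not be split or merged evenly across ranks.
        raise ValueError(
            f"TP sizes {source_tp_size} -> {target_tp_size} are not divisible"
        )


validate_topology_sketch(4, 2)  # accepted: 4 is a multiple of 2
validate_topology_sketch(2, 2)  # accepted: homogeneous mapping

try:
    validate_topology_sketch(3, 2)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```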