Skip to content

vllm_gaudi.extension.unified

BlocksT module-attribute

BlocksT: TypeAlias = Union[tensor, int]

CacheUtils

Helper utilities for kv-cache

Parameters:

Name Type Description Default
is_mla

If True, cache stores MLA latent vectors (no head dimension, single cache). If False, standard attention with per-head K/V caches.

False
Source code in vllm_gaudi/extension/unified.py
class CacheUtils:
    """Helper utilities for kv-cache

    Wraps the key/value cache tensors and fetches selected blocks from them.
    When fetching, dim 0 of each cache is split into (-1, block_size) blocks.

    Args:
        key_cache: Key cache tensor (the latent cache in MLA mode). In standard
            mode its dim 1 is used as the number of kv heads.
        value_cache: Value cache tensor; must be None in MLA mode.
        block_size: Number of slots per cache block.
        k_scales: Optional key scales; stored here, not read by this class.
        v_scales: Optional value scales; stored here, not read by this class.
        is_mla: If True, cache stores MLA latent vectors (no head dimension, single cache).
                If False, standard attention with per-head K/V caches.
    """

    def __init__(self, key_cache, value_cache, block_size, k_scales=None, v_scales=None, is_mla=False):
        self.key_cache = key_cache
        self.value_cache = value_cache
        self.block_size = block_size
        self.is_mla = is_mla

        # MLA stores latent vectors in a single cache
        if is_mla:
            assert value_cache is None, "MLA mode requires value_cache=None (latent stored in key_cache)"

        # MLA latents carry no head dimension, so they are treated as one "head".
        self.kv_heads = 1 if is_mla else key_cache.size(1)
        self.k_scales = k_scales
        self.v_scales = v_scales

    def fetch_shared(self, blocks: BlocksT) -> Union[torch.tensor, tuple[torch.tensor, torch.tensor]]:
        """Fetch selected shared blocks.

        Returns a ``(key, value)`` tuple for standard attention, or a single
        latent tensor when ``value_cache`` is None (MLA mode).
        """
        return self._fetch_all(self._fetch_single_shared, blocks)

    def fetch_unique(self, blocks: BlocksT) -> Union[torch.tensor, tuple[torch.tensor, torch.tensor]]:
        """Fetch selected unique blocks.

        Returns a ``(key, value)`` tuple for standard attention, or a single
        latent tensor when ``value_cache`` is None (MLA mode).
        """
        return self._fetch_all(self._fetch_single_unique, blocks)

    def _fetch_all(self, fn: Callable[[torch.tensor, BlocksT], torch.tensor],
                   blocks: BlocksT) -> Union[torch.tensor, tuple[torch.tensor, torch.tensor]]:
        """Fetch both key and values using selected function.

        Only ``key_cache`` is read when ``value_cache`` is None (MLA), and a
        single tensor is returned in that case.
        """
        if self.value_cache is None:
            return fn(self.key_cache, blocks)
        return fn(self.key_cache, blocks), fn(self.value_cache, blocks)

    def _fetch_single_shared(self, cache: torch.tensor, blocks: BlocksT) -> torch.tensor:
        """Fetch selected shared blocks from given cache.

        ``blocks`` is used with ``index_select``, so it must be an index tensor.
        In standard mode the result is transposed and regrouped so that the
        kv-head dimension leads.
        """
        result = cache.unflatten(0, (-1, self.block_size)).index_select(0, blocks).flatten(0, 1)
        if not self.is_mla:
            result = result.transpose(0, 1).unflatten(0, (self.kv_heads, -1))
        return result

    def _fetch_single_unique(self, cache: torch.tensor, blocks: BlocksT) -> torch.tensor:
        """Fetch selected unique blocks from given cache.

        ``blocks`` is either an index tensor (gathered with ``index_select``) or
        an int ``n`` selecting the first ``n`` blocks.

        Raises:
            RuntimeError: if ``blocks`` is neither a tensor nor an int.
        """
        cache = cache.unflatten(0, (-1, self.block_size))
        if not self.is_mla:
            cache = cache.transpose(1, 2)

        if torch.is_tensor(blocks):
            result = cache.index_select(0, blocks)
        elif type(blocks) is int:  # exact int only (bool is deliberately rejected)
            result = cache[:blocks]
        else:
            raise RuntimeError(f'Unsupported type for blocks: {type(blocks)}')

        if not self.is_mla:
            result = result.unflatten(1, (self.kv_heads, -1))
        else:
            result = result.flatten(0, 1)
        return result

block_size instance-attribute

block_size = block_size

is_mla instance-attribute

is_mla = is_mla

k_scales instance-attribute

k_scales = k_scales

key_cache instance-attribute

key_cache = key_cache

kv_heads instance-attribute

kv_heads = 1 if is_mla else size(1)

v_scales instance-attribute

v_scales = v_scales

value_cache instance-attribute

value_cache = value_cache

__init__

__init__(
    key_cache,
    value_cache,
    block_size,
    k_scales=None,
    v_scales=None,
    is_mla=False,
)
Source code in vllm_gaudi/extension/unified.py
def __init__(self, key_cache, value_cache, block_size, k_scales=None, v_scales=None, is_mla=False):
    """Store the caches and derive the number of kv heads.

    In MLA mode the latent vectors live in ``key_cache`` alone, so
    ``value_cache`` must be None and the head count is fixed at 1. Otherwise
    the head count is read from dim 1 of ``key_cache``.
    """
    self.key_cache, self.value_cache = key_cache, value_cache
    self.block_size, self.is_mla = block_size, is_mla

    if is_mla:
        # Latent vectors are kept in key_cache only.
        assert value_cache is None, "MLA mode requires value_cache=None (latent stored in key_cache)"
        self.kv_heads = 1
    else:
        self.kv_heads = key_cache.size(1)

    self.k_scales, self.v_scales = k_scales, v_scales

_fetch_all

_fetch_all(
    fn: Callable[[tensor, BlocksT], tensor], blocks: BlocksT
) -> tuple[tensor, tensor]

Fetch both key and values using selected function

Source code in vllm_gaudi/extension/unified.py
def _fetch_all(self, fn: Callable[[torch.tensor, BlocksT], torch.tensor],
               blocks: BlocksT) -> Union[torch.tensor, tuple[torch.tensor, torch.tensor]]:
    """Fetch both key and values using selected function.

    Applies ``fn`` to ``key_cache`` and, when present, ``value_cache``.

    Returns:
        A single tensor when ``value_cache`` is None (MLA mode, latents live in
        ``key_cache``); otherwise a ``(key, value)`` tuple.
    """
    if self.value_cache is None:
        return fn(self.key_cache, blocks)
    return fn(self.key_cache, blocks), fn(self.value_cache, blocks)

_fetch_single_shared

_fetch_single_shared(
    cache: tensor, blocks: BlocksT
) -> tensor

Fetch selected shared blocks from given cache

Source code in vllm_gaudi/extension/unified.py
def _fetch_single_shared(self, cache: torch.tensor, blocks: BlocksT) -> torch.tensor:
    """Gather the shared ``blocks`` from one cache.

    Dim 0 of ``cache`` is split into (-1, block_size). The requested blocks are
    selected with ``index_select`` and their slots are flattened back together.
    In standard (non-MLA) mode the result is additionally transposed and
    regrouped so that the kv-head dimension leads.
    """
    blocked = cache.unflatten(0, (-1, self.block_size))
    gathered = blocked.index_select(0, blocks).flatten(0, 1)
    if self.is_mla:
        return gathered
    return gathered.transpose(0, 1).unflatten(0, (self.kv_heads, -1))

_fetch_single_unique

_fetch_single_unique(
    cache: tensor, blocks: BlocksT
) -> tensor

Fetch selected unique blocks from given cache

Source code in vllm_gaudi/extension/unified.py
def _fetch_single_unique(self, cache: torch.tensor, blocks: BlocksT) -> torch.tensor:
    """Gather unique blocks from one cache.

    ``blocks`` may be an index tensor (gathered with ``index_select``) or an int
    ``n``, in which case the first ``n`` blocks are taken. Any other type raises
    RuntimeError. In MLA mode the selected blocks' slots are flattened together;
    otherwise dim 1 is regrouped into (kv_heads, -1).
    """
    blocked = cache.unflatten(0, (-1, self.block_size))
    if not self.is_mla:
        blocked = blocked.transpose(1, 2)

    if torch.is_tensor(blocks):
        selected = blocked.index_select(0, blocks)
    elif type(blocks) is int:
        selected = blocked[:blocks]
    else:
        raise RuntimeError(f'Unsupported type for blocks: {type(blocks)}')

    if self.is_mla:
        return selected.flatten(0, 1)
    return selected.unflatten(1, (self.kv_heads, -1))

fetch_shared

fetch_shared(blocks: BlocksT) -> tensor

Fetch selected shared blocks

Source code in vllm_gaudi/extension/unified.py
def fetch_shared(self, blocks: BlocksT) -> Union[torch.tensor, tuple[torch.tensor, torch.tensor]]:
    """Fetch selected shared blocks.

    Returns:
        A ``(key, value)`` tuple for standard attention, or a single latent
        tensor when ``value_cache`` is None (MLA mode).
    """
    return self._fetch_all(self._fetch_single_shared, blocks)

fetch_unique

fetch_unique(blocks: BlocksT) -> tensor

Fetch selected unique blocks

Source code in vllm_gaudi/extension/unified.py
def fetch_unique(self, blocks: BlocksT) -> Union[torch.tensor, tuple[torch.tensor, torch.tensor]]:
    """Fetch selected unique blocks.

    Returns:
        A ``(key, value)`` tuple for standard attention, or a single latent
        tensor when ``value_cache`` is None (MLA mode).
    """
    return self._fetch_all(self._fetch_single_unique, blocks)

HPUUnifiedAttentionMetadata dataclass

Source code in vllm_gaudi/extension/unified.py
@dataclass
class HPUUnifiedAttentionMetadata:
    """Inputs for one unified-attention step.

    Bundles the following:
      * the causal part (``causal_bias``/``causal_width``);
      * the shared-block part (``shared_blocks`` with either a full
        ``shared_bias`` or chunked bias data);
      * the unique-block part;
      * softmax constants (``fmin``/``feps``);
      * cached softmax_fa2 input tensors;
      * the merge/graph-splitting flags.
    """
    block_size: int
    slot_mapping: torch.tensor
    causal_bias: Optional[torch.tensor]
    causal_width: int
    shared_blocks: Optional[torch.tensor]
    shared_bias: Optional[torch.tensor]
    # Chunked bias data for chunk-wise computation (used when shared_bias is None but shared_blocks exists)
    shared_bias_chunked: Optional[SharedBlockChunkedBiasData]
    shared_chunk_size: int  # Number of blocks to process per chunk (0 = use full bias)
    # Either a tensor of block indices or an int count of leading blocks
    unique_blocks: Optional[torch.tensor] | Optional[int]
    unique_block_mapping: Optional[torch.tensor]
    unique_bias: Optional[torch.tensor]
    fmin: torch.tensor
    feps: torch.tensor
    inputL_hpu_tensors: Optional[Dict[tuple, torch.Tensor]]
    inputM_hpu_tensors: Optional[Dict[tuple, torch.Tensor]]
    online_merge: bool
    split_graphs: bool

    def seq_len(self):
        """Return the last dim of ``slot_mapping`` when a causal bias is set, else 1."""
        # TODO: This needs to be changed in case of mixed batches
        if self.causal_bias is None:
            return 1
        return self.slot_mapping.size(-1)

    def num_blocks(self):
        """Return the number of shared blocks plus the number of unique blocks."""
        shared = 0 if self.shared_blocks is None else self.shared_blocks.size(-1)
        if self.unique_blocks is None:
            unique = 0
        elif torch.is_tensor(self.unique_blocks):
            unique = self.unique_blocks.size(-1)
        else:
            # Integer form: the value itself is the block count.
            unique = self.unique_blocks
        return shared + unique

    @property
    def is_prompt(self):
        """True when a causal bias is present."""
        return self.causal_bias is not None

block_size instance-attribute

block_size: int

causal_bias instance-attribute

causal_bias: Optional[tensor]

causal_width instance-attribute

causal_width: int

feps instance-attribute

feps: tensor

fmin instance-attribute

fmin: tensor

inputL_hpu_tensors instance-attribute

inputL_hpu_tensors: Optional[Dict[tuple, Tensor]]

inputM_hpu_tensors instance-attribute

inputM_hpu_tensors: Optional[Dict[tuple, Tensor]]

is_prompt property

is_prompt

online_merge instance-attribute

online_merge: bool

shared_bias instance-attribute

shared_bias: Optional[tensor]

shared_bias_chunked instance-attribute

shared_bias_chunked: Optional[SharedBlockChunkedBiasData]

shared_blocks instance-attribute

shared_blocks: Optional[tensor]

shared_chunk_size instance-attribute

shared_chunk_size: int

slot_mapping instance-attribute

slot_mapping: tensor

split_graphs instance-attribute

split_graphs: bool

unique_bias instance-attribute

unique_bias: Optional[tensor]

unique_block_mapping instance-attribute

unique_block_mapping: Optional[tensor]

unique_blocks instance-attribute

unique_blocks: Optional[tensor] | Optional[int]

__init__

__init__(
    block_size: int,
    slot_mapping: tensor,
    causal_bias: Optional[tensor],
    causal_width: int,
    shared_blocks: Optional[tensor],
    shared_bias: Optional[tensor],
    shared_bias_chunked: Optional[
        SharedBlockChunkedBiasData
    ],
    shared_chunk_size: int,
    unique_blocks: Optional[tensor] | Optional[int],
    unique_block_mapping: Optional[tensor],
    unique_bias: Optional[tensor],
    fmin: tensor,
    feps: tensor,
    inputL_hpu_tensors: Optional[Dict[tuple, Tensor]],
    inputM_hpu_tensors: Optional[Dict[tuple, Tensor]],
    online_merge: bool,
    split_graphs: bool,
) -> None

num_blocks

num_blocks()
Source code in vllm_gaudi/extension/unified.py
def num_blocks(self):
    """Return the number of shared blocks plus the number of unique blocks.

    ``unique_blocks`` may be a tensor (its last dim is counted) or a plain int
    count; a None field contributes nothing.
    """
    total = self.shared_blocks.size(-1) if self.shared_blocks is not None else 0
    unique = self.unique_blocks
    if unique is None:
        return total
    return total + (unique.size(-1) if torch.is_tensor(unique) else unique)

seq_len

seq_len()
Source code in vllm_gaudi/extension/unified.py
def seq_len(self):
    """Return the last dim of ``slot_mapping`` when a causal bias is set, else 1."""
    # TODO: This needs to be changed in case of mixed batches
    if self.causal_bias is None:
        return 1
    return self.slot_mapping.size(-1)

SharedBlockChunkedBiasData dataclass

Data needed to compute shared block bias per-chunk during chunked attention.

This avoids materializing the full [query_len, num_shared_blocks, block_size] bias tensor which can be prohibitively large with many shared blocks.

Contains dense block_usages of shape (num_query_tokens, num_shared_blocks). During chunked attention, we slice block_usages[:, chunk_start:chunk_end] and generate bias for each chunk on-the-fly.

Source code in vllm_gaudi/extension/unified.py
@dataclass
class SharedBlockChunkedBiasData:
    """Inputs for generating shared-block attention bias one chunk at a time.

    Materializing the full [query_len, num_shared_blocks, block_size] bias can
    be prohibitively large when there are many shared blocks. This class keeps
    only the dense ``block_usages`` matrix of shape
    (num_query_tokens, num_shared_blocks). Chunked attention slices
    ``block_usages[:, chunk_start:chunk_end]`` and builds that chunk's bias on
    the fly.
    """
    # Dense per-token block usage: [num_query_tokens, num_shared_blocks]
    block_usages: torch.tensor
    # Total query length (padded)
    num_query_tokens: int
    # Total number of shared blocks (padded)
    num_shared_blocks: int
    # When True, chunked attention calls htorch.core.mark_step() around each chunk
    split_chunked_graphs: bool

block_usages instance-attribute

block_usages: tensor

num_query_tokens instance-attribute

num_query_tokens: int

num_shared_blocks instance-attribute

num_shared_blocks: int

split_chunked_graphs instance-attribute

split_chunked_graphs: bool

__init__

__init__(
    block_usages: tensor,
    num_query_tokens: int,
    num_shared_blocks: int,
    split_chunked_graphs: bool,
) -> None

_partial_attn_shared_chunked

_partial_attn_shared_chunked(
    query: tensor,
    blocks: tensor,
    bias: Optional[tensor],
    chunked_data: SharedBlockChunkedBiasData,
    chunk_size: int,
    fmin: tensor,
    inputL_hpu_tensors: Dict[tuple, Tensor],
    inputM_hpu_tensors: Dict[tuple, Tensor],
    cache_utils: CacheUtils,
    dtype: dtype,
    w_uv: Optional[tensor] = None,
) -> tuple[tensor, tensor, tensor]

Chunked implementation of partial_attn_shared with per-chunk bias generation.

Generates bias per chunk from dense block_usages to save memory. Avoids materializing the full (query_len, num_blocks, block_size) bias tensor.

Strategy: (1) process blocks in chunks of chunk_size; (2) for each chunk, slice the pre-computed bias if one was given, otherwise generate the chunk bias on the fly from a slice of block_usages; (3) compute attention for the chunk using _partial_attn_shared_core; (4) merge the chunk results using flash-attention-style online softmax.

Source code in vllm_gaudi/extension/unified.py
def _partial_attn_shared_chunked(
        query: torch.tensor,
        blocks: torch.tensor,
        bias: Optional[torch.tensor],
        chunked_data: SharedBlockChunkedBiasData,
        chunk_size: int,
        fmin: torch.tensor,
        inputL_hpu_tensors: Dict[tuple, torch.Tensor],
        inputM_hpu_tensors: Dict[tuple, torch.Tensor],
        cache_utils: CacheUtils,
        dtype: torch.dtype,
        w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]:
    """Chunked implementation of partial_attn_shared with per-chunk bias generation.

    Generates bias per chunk from dense block_usages to save memory.
    Avoids materializing the full (query_len, num_blocks, block_size) bias tensor.

    Strategy:
    1. Process blocks in chunks of chunk_size
    2. For each chunk, slice block_usages and generate chunk bias on-the-fly
       (or slice ``bias`` when a pre-computed bias is passed)
    3. Compute attention for the chunk using _partial_attn_shared_core
    4. Merge chunk results using flash-attention style online softmax

    Args:
        query: Query tensor; used as ``query.transpose(0, 1)``, so presumably
            [tokens, num_heads, head_dim] - TODO confirm.
        blocks: Shared block indices; sliced along dim 0 per chunk.
        bias: Optional pre-computed bias. When given it is sliced as
            ``bias[:, chunk_start:chunk_end, :]`` instead of being generated.
        chunked_data: Dense block usages and padded sizes used to build per-chunk bias.
        chunk_size: Number of blocks per chunk. NOTE(review): must be > 0;
            ``num_blocks / chunk_size`` raises ZeroDivisionError otherwise.
        fmin: Forwarded to the core, which uses it as a floor for the row max.
        inputL_hpu_tensors: Cached softmax_fa2 inputs, forwarded to the core.
        inputM_hpu_tensors: Cached softmax_fa2 inputs, forwarded to the core.
        cache_utils: Provides block_size, kv_heads and fetch_shared().
        dtype: dtype of the generated per-chunk bias buffer.
        w_uv: MLA projection matrix; passing it selects the MLA path.

    Returns:
        (unnormalized_weighted_V, running_max, running_sum) merged over all
        chunks, or (None, None, None) when there are no shared blocks.
    """
    num_blocks = chunked_data.num_shared_blocks
    block_size = cache_utils.block_size
    num_query_tokens = chunked_data.num_query_tokens

    # MLA mode is selected purely by the presence of the w_uv projection.
    is_mla = w_uv is not None
    kv_heads = 1 if is_mla else cache_utils.kv_heads

    # Calculate number of chunks
    num_chunks = math.ceil(num_blocks / chunk_size)

    # Check if we have pre-computed bias or need to generate per-chunk
    generate_bias_per_chunk = (bias is None)

    # Pre-allocate reusable tensors outside the loop (avoid allocations per iteration)
    if generate_bias_per_chunk:
        # 1-based slot positions [1..block_size]; a slot is masked when its
        # position exceeds the token's usage count for that block.
        block_len_range = torch.arange(1,
                                       block_size + 1,
                                       dtype=chunked_data.block_usages.dtype,
                                       device=chunked_data.block_usages.device)
        # Pre-allocate chunk_bias buffer - will be overwritten each iteration
        chunk_bias_buffer = torch.empty((num_query_tokens, chunk_size, block_size),
                                        dtype=dtype,
                                        device=chunked_data.block_usages.device)

    # Accumulators for online softmax-style merging
    accumulated_attn = None
    global_max = None
    global_sum = None
    split_graphs = chunked_data.split_chunked_graphs
    for chunk_idx in range(num_chunks):
        # NOTE(review): mark_step presumably cuts the lazy-mode HPU graph so each
        # chunk is executed as its own graph - confirm.
        if split_graphs:
            htorch.core.mark_step()
        chunk_start = chunk_idx * chunk_size
        chunk_end = min(chunk_start + chunk_size, num_blocks)
        # The final chunk may be shorter than chunk_size.
        actual_chunk_len = chunk_end - chunk_start

        # Slice blocks for this chunk
        chunk_blocks = blocks[chunk_start:chunk_end]

        if generate_bias_per_chunk:
            # Generate bias for this chunk from dense block_usages
            # chunked_data.block_usages is (num_query_tokens, num_shared_blocks)
            # Slice to get (num_query_tokens, actual_chunk_len) for this chunk
            chunk_block_usages = chunked_data.block_usages[:, chunk_start:chunk_end]

            # Generate chunk bias using dense broadcast into pre-allocated buffer
            # chunk_block_usages.unsqueeze(-1): (num_query_tokens, actual_chunk_len, 1)
            # broadcast comparison: (num_query_tokens, actual_chunk_len, block_size)
            chunk_mask = block_len_range > chunk_block_usages.unsqueeze(-1)

            # Use view of pre-allocated buffer for actual chunk size
            chunk_bias = chunk_bias_buffer[:, :actual_chunk_len, :]
            # Reset to 0, then set unused slots to -inf so they get zero softmax weight.
            chunk_bias.zero_()
            chunk_bias.masked_fill_(chunk_mask, -math.inf)
        else:
            # Pre-computed: slice from full bias tensor
            chunk_bias = bias[:, chunk_start:chunk_end, :]

        # Fetch KV for this chunk
        if is_mla:
            # MLA: the single latent cache serves as both K and V and is
            # broadcast (expand, no copy) over all query heads.
            latent_kv = cache_utils.fetch_shared(chunk_blocks)
            num_heads = query.size(1)
            query_t = query.transpose(0, 1).unsqueeze(1)
            key = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1)
            value = latent_kv.unsqueeze(0).unsqueeze(0).expand(num_heads, 1, -1, -1)
        else:
            # Standard attention: group query heads under their kv head.
            query_t = query.transpose(0, 1).unflatten(0, (kv_heads, -1))
            key, value = cache_utils.fetch_shared(chunk_blocks)

        # Flatten bias for attention: [1, query_len, chunk_len * block_size]
        chunk_bias_flat = chunk_bias.flatten(-2, -1).unsqueeze(0)

        # Compute attention for this chunk
        chunk_attn, chunk_max, chunk_sum = _partial_attn_shared_core(query_t, key, value, chunk_bias_flat, fmin,
                                                                     inputL_hpu_tensors, inputM_hpu_tensors, kv_heads,
                                                                     is_mla, w_uv)

        # Online merge: combine this chunk with accumulated results
        if accumulated_attn is None:
            # First chunk - just store
            accumulated_attn = chunk_attn
            global_max = chunk_max
            global_sum = chunk_sum
        else:
            # Merge with existing - use flash-attention style rescaling
            # (same arithmetic as online_merge_step)
            new_max = torch.maximum(global_max, chunk_max)

            # Rescale factors
            old_scale = torch.exp(global_max - new_max)
            new_scale = torch.exp(chunk_max - new_max)

            # Rescale accumulated values and sums
            accumulated_attn = accumulated_attn * old_scale.unsqueeze(-1) + chunk_attn * new_scale.unsqueeze(-1)
            global_sum = global_sum * old_scale + chunk_sum * new_scale
            global_max = new_max

        if split_graphs:
            htorch.core.mark_step()

    # No chunks were processed (num_shared_blocks == 0).
    if accumulated_attn is None:
        return (None, None, None)

    return accumulated_attn, global_max, global_sum

_partial_attn_shared_core

_partial_attn_shared_core(
    query: tensor,
    key: tensor,
    value: tensor,
    bias: tensor,
    fmin: tensor,
    inputL_hpu_tensors: Dict[tuple, Tensor],
    inputM_hpu_tensors: Dict[tuple, Tensor],
    kv_heads: int,
    is_mla: bool,
    w_uv: Optional[tensor] = None,
) -> tuple[tensor, tensor, tensor]

Core shared attention computation.

This is the inner loop extracted for reuse between full and chunked paths.

Parameters:

Name Type Description Default
query tensor

Query tensor, already transposed: [kv_heads, q_heads_per_kv, tokens, head_dim] for standard attention, or [num_heads, 1, tokens, head_dim] for MLA

required
key tensor

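Key tensor from cache: [kv_heads, 1, kv_len, head_dim] for standard attention (broadcast against the query's q_heads_per_kv dimension), or the latent cache expanded to [num_heads, 1, kv_len, dim] for MLA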

required
value tensor

Value tensor from cache

required
bias tensor

Attention bias [1, tokens, kv_len] (the per-token bias, already flattened from [tokens, num_blocks, block_size])

required
fmin tensor

Minimum float for softmax stability

required
inputL_hpu_tensors Dict[tuple, Tensor]

Cache for FA2 tensors

required
inputM_hpu_tensors Dict[tuple, Tensor]

Cache for FA2 tensors

required
kv_heads int

Number of KV heads

required
is_mla bool

Whether using MLA attention

required
w_uv Optional[tensor]

Optional MLA projection matrix

None

Returns:

Type Description
tuple[tensor, tensor, tensor]

Tuple of (unnormalized_weighted_V, local_max, local_sum)

Source code in vllm_gaudi/extension/unified.py
def _partial_attn_shared_core(query: torch.tensor,
                              key: torch.tensor,
                              value: torch.tensor,
                              bias: torch.tensor,
                              fmin: torch.tensor,
                              inputL_hpu_tensors: Dict[tuple, torch.Tensor],
                              inputM_hpu_tensors: Dict[tuple, torch.Tensor],
                              kv_heads: int,
                              is_mla: bool,
                              w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]:
    """Core shared attention computation.

    This is the inner loop extracted for reuse between full and chunked paths.
    It returns *unnormalized* softmax-weighted values together with the row max
    and row sum, so callers can merge partial results and normalize at the end.

    Args:
        query: Query tensor, already transposed. Standard attention:
            [kv_heads, q_heads_per_kv, tokens, head_dim]; MLA:
            [num_heads, 1, tokens, dim].
        key: Key tensor from cache. Standard attention, as produced by
            CacheUtils.fetch_shared: [kv_heads, 1, kv_len, head_dim]; MLA: the
            latent cache expanded to [num_heads, 1, kv_len, dim].
        value: Value tensor from cache (same layout as ``key``).
        bias: Attention bias [1, tokens, kv_len], flattened from
            [tokens, num_blocks, block_size] by the callers.
        fmin: Floor for the row max (non-FA2 path) and the fill value of the
            softmax_fa2 M input.
        inputL_hpu_tensors: Cache for FA2 tensors
        inputM_hpu_tensors: Cache for FA2 tensors
        kv_heads: Number of KV heads
        is_mla: Whether using MLA attention
        w_uv: Optional MLA projection; when given with ``is_mla``, the first
            ``w_uv.size(1)`` features are projected through it with ``bmm``.

    Returns:
        Tuple of (unnormalized_weighted_V, local_max, local_sum), each with the
        head and token dims swapped back (``transpose(0, 1)``), i.e. tokens first.
    """
    num_heads = query.size(0) * query.size(1) if not is_mla else query.size(0)

    # Scores, then merge the two leading head dims into one: [heads, tokens, kv_len].
    attn = torch.matmul(query, key.transpose(-1, -2))
    attn = attn.flatten(0, 1)
    attn = attn + bias

    # TODO: remove dtype check once full support is added for fp8 in unified attention
    if get_config().unified_attn_softmax_fa2 and attn.dtype == torch.bfloat16:
        inputM_hpu, inputL_hpu = create_softmax_fa2_input_tensors(attn, fmin, inputL_hpu_tensors, inputM_hpu_tensors)
        attn, local_max, local_sum, _exp_max_fixup_hpu = torch.ops.hpu.softmax_fa2(attn,
                                                                                   inputM=inputM_hpu,
                                                                                   inputL=inputL_hpu)
        # softmax_fa2 returns max/sum in a CL-aligned layout; trim back to attn.shape[:-1].
        local_max = convert_cl_aligned_tensor(local_max, list(attn.shape[:-1]))
        local_sum = convert_cl_aligned_tensor(local_sum, list(attn.shape[:-1]))
    else:
        # Exp-shifted scores without dividing by the sum; normalization happens at merge time.
        local_max = torch.maximum(attn.amax(-1), fmin)
        attn = torch.exp(attn - local_max.unsqueeze(-1))
        local_sum = attn.sum(-1)

    # Restore the head grouping expected by ``value`` before the weighted sum.
    attn = torch.matmul(attn.unflatten(0, (kv_heads if not is_mla else num_heads, -1)), value).flatten(0, 1)

    # MLA: Extract latent part and project to full V
    if is_mla and w_uv is not None:
        latent_dim = w_uv.size(1)
        attn_latent = attn[..., :latent_dim]
        attn = torch.bmm(attn_latent, w_uv)

    return attn.transpose(0, 1), local_max.transpose(0, 1), local_sum.transpose(0, 1)

_partial_attn_shared_full

_partial_attn_shared_full(
    query: tensor,
    blocks: tensor,
    bias: tensor,
    fmin: tensor,
    inputL_hpu_tensors: Dict[tuple, Tensor],
    inputM_hpu_tensors: Dict[tuple, Tensor],
    cache_utils: CacheUtils,
    w_uv: Optional[tensor] = None,
) -> tuple[tensor, tensor, tensor]

Full bias implementation of partial_attn_shared.

Source code in vllm_gaudi/extension/unified.py
def _partial_attn_shared_full(query: torch.tensor,
                              blocks: torch.tensor,
                              bias: torch.tensor,
                              fmin: torch.tensor,
                              inputL_hpu_tensors: Dict[tuple, torch.Tensor],
                              inputM_hpu_tensors: Dict[tuple, torch.Tensor],
                              cache_utils: CacheUtils,
                              w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]:
    """Full bias implementation of partial_attn_shared.

    Fetches every shared block at once and lays out query/key/value for either
    MLA (``w_uv`` given) or standard attention. It then flattens the last two
    dims of ``bias`` and delegates to ``_partial_attn_shared_core``.

    Returns:
        (unnormalized_weighted_V, local_max, local_sum) from the core.
    """
    use_mla = w_uv is not None

    if not use_mla:
        # Standard attention: separate K and V caches, queries grouped per kv head.
        heads_kv = cache_utils.kv_heads
        q = query.transpose(0, 1).unflatten(0, (heads_kv, -1))
        k, v = cache_utils.fetch_shared(blocks)
    else:
        # MLA: one latent cache serves as both K and V, broadcast over all query heads.
        latent = cache_utils.fetch_shared(blocks)
        head_count = query.size(1)
        q = query.transpose(0, 1).unsqueeze(1)  # [num_heads, 1, tokens, latent_dim + rope_dim]
        k = latent.unsqueeze(0).unsqueeze(0).expand(head_count, 1, -1, -1)
        v = latent.unsqueeze(0).unsqueeze(0).expand(head_count, 1, -1, -1)
        heads_kv = 1

    flat_bias = bias.flatten(-2, -1).unsqueeze(0)
    return _partial_attn_shared_core(q, k, v, flat_bias, fmin, inputL_hpu_tensors, inputM_hpu_tensors, heads_kv,
                                     use_mla, w_uv)

block2batch

block2batch(tensor, block_mapping)

Convert from block to batch on dim=0

Source code in vllm_gaudi/extension/unified.py
def block2batch(tensor, block_mapping):
    """Convert from block to batch on dim=0.

    Combines the per-block rows of ``tensor`` into per-batch rows by computing
    ``block_mapping.t() @ tensor``, weighted by ``block_mapping`` (a 0/1
    mapping simply sums each batch's blocks). Trailing dims are flattened for
    the matmul and restored afterwards.
    """
    trailing_shape = tensor.shape[1:]
    per_batch = block_mapping.t() @ tensor.flatten(1, -1)
    return per_batch.unflatten(-1, trailing_shape)

convert_cl_aligned_tensor

convert_cl_aligned_tensor(
    input_hpu, reference_size
) -> tensor

Convert a CL-aligned tensor to the reference size

Source code in vllm_gaudi/extension/unified.py
def convert_cl_aligned_tensor(input_hpu, reference_size) -> torch.tensor:
    """Convert a CL-aligned tensor to the reference size.

    Steps:
      1. Reshape ``input_hpu`` so its last dims are (-1, vec_size).
      2. Keep the first ``pack_size`` entries of every vector.
      3. Flatten those packs back into a single dimension.
      4. Trim the result to ``reference_size[-1]`` elements.
    """
    vec_size, pack_size = get_vecsize_packsize(input_hpu.dtype)

    vector_shape = list(reference_size)
    vector_shape[-1] = -1  # let reshape infer the number of vectors
    vector_shape.append(vec_size)

    packed = input_hpu.reshape(vector_shape)[..., :pack_size]
    unpacked = packed.flatten(-2, -1)
    return unpacked[..., :reference_size[-1]]

create_softmax_fa2_input_tensors

create_softmax_fa2_input_tensors(
    attn: tensor,
    fmin: tensor,
    inputL_hpu_tensors: Dict[tuple, Tensor],
    inputM_hpu_tensors: Dict[tuple, Tensor],
) -> tuple[tensor, tensor]

Create dummy input tensors for the softmax_fa2 operation.

Source code in vllm_gaudi/extension/unified.py
def create_softmax_fa2_input_tensors(
        attn: torch.tensor, fmin: torch.tensor, inputL_hpu_tensors: Dict[tuple, torch.Tensor],
        inputM_hpu_tensors: Dict[tuple, torch.Tensor]) -> tuple[torch.tensor, torch.tensor]:
    """Create the M (running max) and L (running sum) input tensors for softmax_fa2.

    The tensors have the CL-aligned shape derived from ``attn.shape[:-1]``: the
    last retained dim is rounded up to whole packs (see get_last_dim_size). M is
    filled with ``fmin`` and L with zeros.

    If both caches already hold a tensor for that shape, the cached tensors are
    reset in place and returned. Otherwise, including when a cache dict is None,
    fresh tensors are allocated on the HPU.

    Returns:
        (inputM, inputL)
    """
    vec_size, pack_size = get_vecsize_packsize(attn.dtype)
    retained_shape = list(attn.shape[:-1])
    retained_shape[-1] = get_last_dim_size(retained_shape[-1], vec_size, pack_size)
    shape_key = tuple(retained_shape)

    cached_m = None if inputM_hpu_tensors is None else inputM_hpu_tensors.get(shape_key)
    cached_l = None if inputL_hpu_tensors is None else inputL_hpu_tensors.get(shape_key)
    if cached_m is None or cached_l is None:
        # Cache miss: allocate temporaries. NOTE(review): these are deliberately not
        # inserted into the caches (unchanged behaviour); confirm whether the callers'
        # pre-allocation is expected to cover every shape.
        input_m = torch.full(retained_shape, fmin, dtype=attn.dtype, device='hpu')
        input_l = torch.zeros(retained_shape, dtype=attn.dtype, device="hpu")
        return input_m, input_l

    # The filling is done on each call to avoid potential stale data issues.
    torch.hpu.synchronize()
    cached_l.zero_()
    cached_m.fill_(fmin)
    return cached_m, cached_l

get_last_dim_size

get_last_dim_size(last_dim, vec_size, pack_size)
Source code in vllm_gaudi/extension/unified.py
def get_last_dim_size(last_dim, vec_size, pack_size):
    """Return the padded size of a CL-aligned last dimension.

    ``last_dim`` elements are grouped into packs of ``pack_size`` (the final
    pack may be partial), and each pack occupies ``vec_size`` slots.
    """
    num_packs = math.ceil(last_dim / pack_size)
    return num_packs * vec_size

get_vecsize_packsize

get_vecsize_packsize(dtype: dtype) -> tuple[int, int]

Get vecsize and packsize for given dtype

Source code in vllm_gaudi/extension/unified.py
def get_vecsize_packsize(dtype: torch.dtype) -> tuple[int, int]:
    """Get vecsize and packsize for given dtype.

    The pack size is always 8. On Gaudi3 the vector size is 128 for bfloat16
    and 64 for any other dtype; on other devices it is 1.
    """
    pack_size = 8
    if not hpu_ops.is_hpu_gaudi3:
        return 1, pack_size
    vec_size = 128 if dtype == torch.bfloat16 else 64
    return vec_size, pack_size

merge

merge(*attn_results: tensor, feps: tensor) -> tensor

Merge partial attention values into final attn score

Source code in vllm_gaudi/extension/unified.py
def merge(*attn_results: torch.tensor, feps: torch.tensor) -> Optional[torch.tensor]:
    """Merge partial attention values into final attn score.

    Each element of ``attn_results`` is an ``(attn, max, sum)`` triple from a
    partial attention pass; a part that is absent is passed as Nones. The
    function proceeds as follows:
      1. Rescale every partial to a shared global max.
      2. Combine the sums and clamp the result from below by ``feps``.
      3. Normalize the rescaled values and add them up.

    Returns:
        The merged attention output, or None when no results are given or
        every partial ``attn`` is None (consistent with ``online_merge``).
    """
    if not attn_results:
        return None
    all_attn, all_max, all_sum = zip(*attn_results)
    if all(a is None for a in all_attn):
        # Nothing to merge; torch.maximum below would reject a None sum.
        return None
    # ``optional`` wraps ops so None operands are skipped (see its docstring).
    global_max = functools.reduce(optional(torch.maximum), all_max)
    calc_adjustment = optional(lambda x: torch.exp((x - global_max)))
    adjust = optional(lambda x, a: x * a)
    all_adj = [calc_adjustment(x) for x in all_max]
    global_sum = functools.reduce(optional(torch.add), [adjust(s, a) for s, a in zip(all_sum, all_adj)])
    global_sum = torch.maximum(global_sum, feps)
    rescale = optional(lambda x, adj: x * (adj / global_sum).unsqueeze(-1))
    attn = [rescale(attn, adj) for attn, adj in zip(all_attn, all_adj)]
    attn = functools.reduce(optional(torch.add), attn)
    return attn

online_merge

online_merge(
    *attn_results: tuple[tensor, tensor, tensor],
    feps: tensor,
) -> Optional[tensor]

Merge partial attention values using online (incremental) algorithm.

Alternative to merge() that uses online_merge_step for incremental merging. This approach is more memory efficient as it doesn't need to keep all intermediate results simultaneously.

Parameters:

Name Type Description Default
attn_results tuple[tensor, tensor, tensor]

Variable number of (attn, max, sum) tuples

()
feps tensor

Small epsilon for numerical stability

required

Returns:

Type Description
Optional[tensor]

Final normalized attention output, or None if all inputs are None

Source code in vllm_gaudi/extension/unified.py
def online_merge(*attn_results: tuple[torch.tensor, torch.tensor, torch.tensor],
                 feps: torch.tensor) -> Optional[torch.tensor]:
    """Merge partial attention values using online (incremental) algorithm.

    Every ``(attn, max, sum)`` triple is folded into a running state with
    ``online_merge_step``, which skips None partials. The running weighted
    values are then divided by the running sum, clamped from below by ``feps``.
    Unlike ``merge``, this never holds all rescaled partials at the same time.

    Args:
        attn_results: Variable number of (attn, max, sum) tuples
        feps: Small epsilon for numerical stability

    Returns:
        Final normalized attention output, or None if all inputs are None
    """
    state = (None, None, None)
    for partial in attn_results:
        part_attn, part_max, part_sum = partial
        state = online_merge_step(*state, part_attn, part_max, part_sum)

    merged_attn, _, merged_sum = state
    if merged_attn is None:
        return None

    # Final normalization
    denominator = torch.maximum(merged_sum, feps).unsqueeze(-1)
    return merged_attn / denominator

online_merge_step

online_merge_step(
    acc_attn: Optional[tensor],
    acc_max: Optional[tensor],
    acc_sum: Optional[tensor],
    new_attn: Optional[tensor],
    new_max: Optional[tensor],
    new_sum: Optional[tensor],
) -> tuple[
    Optional[tensor], Optional[tensor], Optional[tensor]
]

Incrementally merge attention results using flash-attention style rescaling.

This implements the online softmax algorithm where we maintain running unnormalized weighted values, max, and sum. The final normalization (dividing by sum) is done at the end.

Parameters:

Name Type Description Default
acc_attn Optional[tensor]

Accumulated unnormalized weighted V [tokens, heads, head_dim] or None

required
acc_max Optional[tensor]

Accumulated max values [tokens, heads] or None

required
acc_sum Optional[tensor]

Accumulated sum of exp values [tokens, heads] or None

required
new_attn Optional[tensor]

New unnormalized weighted V to merge

required
new_max Optional[tensor]

New max values to merge

required
new_sum Optional[tensor]

New sum of exp values to merge

required

Returns:

Type Description
tuple[Optional[tensor], Optional[tensor], Optional[tensor]]

Tuple of (merged_attn, merged_max, merged_sum)

Source code in vllm_gaudi/extension/unified.py
def online_merge_step(
    acc_attn: Optional[torch.tensor],
    acc_max: Optional[torch.tensor],
    acc_sum: Optional[torch.tensor],
    new_attn: Optional[torch.tensor],
    new_max: Optional[torch.tensor],
    new_sum: Optional[torch.tensor],
) -> tuple[Optional[torch.tensor], Optional[torch.tensor], Optional[torch.tensor]]:
    """Fold one partial attention result into a running accumulator.

    Uses the online-softmax (flash-attention style) trick: the running and the
    incoming partial results are both rescaled to a common maximum before their
    unnormalized outputs and exp-sums are added together. Dividing by the final
    sum is left to the caller.

    Args:
        acc_attn: Running unnormalized weighted V [tokens, heads, head_dim], or None
            when nothing has been accumulated yet.
        acc_max: Running max values [tokens, heads], or None.
        acc_sum: Running sum of exp values [tokens, heads], or None.
        new_attn: Incoming unnormalized weighted V, or None to skip this step.
        new_max: Incoming max values.
        new_sum: Incoming sum of exp values.

    Returns:
        Tuple (merged_attn, merged_max, merged_sum). When ``new_attn`` is None the
        accumulator triple is returned unchanged; when ``acc_attn`` is None the
        incoming triple is returned as-is.
    """
    # Nothing new to fold in: the accumulator stays as it is.
    if new_attn is None:
        return acc_attn, acc_max, acc_sum
    # First contribution simply becomes the accumulator.
    if acc_attn is None:
        return new_attn, new_max, new_sum

    shared_max = torch.maximum(acc_max, new_max)
    # Weights that move each side from its own max to the shared max.
    acc_weight = torch.exp(acc_max - shared_max)
    new_weight = torch.exp(new_max - shared_max)

    combined_attn = acc_attn * acc_weight.unsqueeze(-1) + new_attn * new_weight.unsqueeze(-1)
    combined_sum = acc_sum * acc_weight + new_sum * new_weight
    return combined_attn, shared_max, combined_sum

optional

optional(op)

Wrap an operation to support handling None values

Source code in vllm_gaudi/extension/unified.py
def optional(op):
    """Return a None-tolerant version of ``op``.

    The returned callable forwards to ``op`` only when every argument is not
    None. Otherwise it returns the single non-None argument if exactly one is
    present, and None in every other case.

    Binary examples:
        op(None, None) -> None
        op(None, B)    -> B
        op(A, None)    -> A
        op(A, B)       -> op(A, B)
    Unary examples:
        op(None) -> None
        op(A)    -> op(A)
    """

    def opt_impl(*args):
        provided = [value for value in args if value is not None]
        if len(provided) < len(args):
            # Some argument is missing: pass a lone survivor through,
            # otherwise there is nothing sensible to return.
            return provided[0] if len(provided) == 1 else None
        return op(*args)

    return opt_impl

partial_attn_causal

partial_attn_causal(
    query: tensor,
    key: tensor,
    value: tensor,
    bias: Optional[tensor],
    slice_size: int,
    fmin: tensor,
    inputL_hpu_tensors: Dict[tuple, Tensor],
    inputM_hpu_tensors: Dict[tuple, Tensor],
    w_uv: Optional[tensor] = None,
) -> tuple[tensor, tensor, tensor]

Partial attention where qkv are assumed to be causal between slices

Parameters:

Name Type Description Default
w_uv Optional[tensor]

Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]. If provided, value is assumed to be in latent space and will be projected.

None
Source code in vllm_gaudi/extension/unified.py
def partial_attn_causal(query: torch.tensor,
                        key: torch.tensor,
                        value: torch.tensor,
                        bias: Optional[torch.tensor],
                        slice_size: int,
                        fmin: torch.tensor,
                        inputL_hpu_tensors: Dict[tuple, torch.Tensor],
                        inputM_hpu_tensors: Dict[tuple, torch.Tensor],
                        w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]:
    """Causal attention among the new tokens, computed one query slice at a time.

    Query rows are processed in chunks of ``slice_size``. Rows in ``[lo, hi)`` are
    scored only against keys/values ``[0, hi)``, and ``bias[lo:hi, :hi]`` is added to
    those scores before the softmax.

    Args:
        query: Scaled query; dim 0 is the token dim and dim 1 is the head dim.
        key: Keys; ``key.size(1)`` gives the number of KV-head groups.
        value: Values, laid out like ``key`` along dims 0 and 1.
        bias: Additive mask indexed as ``bias[q_rows, kv_cols]``. If None, nothing is computed.
        slice_size: Number of query rows per slice.
        fmin: Lower bound on each row max in the exp path. Also passed to the fa2 input helper.
        inputL_hpu_tensors: Cache handed to ``create_softmax_fa2_input_tensors``.
        inputM_hpu_tensors: Cache handed to ``create_softmax_fa2_input_tensors``.
        w_uv: Optional MLA projection matrix [num_heads, latent_dim, v_head_dim].
              If provided, value is assumed to be in latent space and will be projected.

    Returns:
        (unnormalized weighted V, row max, row sum), each with tokens first and heads
        second, or (None, None, None) when ``bias`` is None.
    """
    if bias is None:
        return (None, None, None)

    num_slices = math.ceil(query.size(0) / slice_size)
    group_count = key.size(1)

    def to_grouped(t):
        # [tokens, heads, dim] -> [group_count, heads // group_count, tokens, dim]
        return t.transpose(0, 1).unflatten(0, (group_count, -1))

    grouped_q = to_grouped(query)
    grouped_k = to_grouped(key)
    grouped_v = to_grouped(value)

    attn_parts, max_parts, sum_parts = [], [], []
    for idx in range(num_slices):
        lo = idx * slice_size
        hi = lo + slice_size
        # Causality: query rows [lo, hi) may only look at key/value rows [0, hi).
        scores = torch.matmul(grouped_q[:, :, lo:hi, :], grouped_k[:, :, :hi, :].transpose(-1, -2))
        scores = scores + bias[lo:hi, :hi].unsqueeze(0).unsqueeze(0)

        # TODO: remove dtype check once full support is added for fp8 in unified attention
        if get_config().unified_attn_softmax_fa2 and scores.dtype == torch.bfloat16:
            inputM_hpu, inputL_hpu = create_softmax_fa2_input_tensors(scores, fmin, inputL_hpu_tensors,
                                                                      inputM_hpu_tensors)
            probs, row_max, row_sum, _exp_max_fixup_hpu = torch.ops.hpu.softmax_fa2(scores,
                                                                                   inputM=inputM_hpu,
                                                                                   inputL=inputL_hpu)
            row_max = convert_cl_aligned_tensor(row_max, list(probs.shape[:-1]))
            row_sum = convert_cl_aligned_tensor(row_sum, list(probs.shape[:-1]))
        else:
            row_max = torch.maximum(scores.amax(-1), fmin)
            probs = torch.exp(scores - row_max.unsqueeze(-1))
            row_sum = torch.sum(probs, -1)

        weighted = torch.matmul(probs, grouped_v[:, :, :hi, :])

        if w_uv is not None:
            # MLA: project each head's latent output to full V through w_uv.
            lead_dims = weighted.shape[:2]
            weighted = torch.bmm(weighted.flatten(0, 1), w_uv).unflatten(0, lead_dims)

        attn_parts.append(weighted)
        max_parts.append(row_max)
        sum_parts.append(row_sum)

    def merge_slices(parts):
        # Concatenate slices along tokens, then restore the tokens-first layout.
        return torch.cat(parts, dim=2).flatten(0, 1).transpose(0, 1)

    return merge_slices(attn_parts), merge_slices(max_parts), merge_slices(sum_parts)

partial_attn_shared

partial_attn_shared(
    query: tensor,
    blocks: tensor,
    bias: Optional[tensor],
    fmin: tensor,
    inputL_hpu_tensors: Dict[tuple, Tensor],
    inputM_hpu_tensors: Dict[tuple, Tensor],
    cache_utils: CacheUtils,
    dtype: dtype,
    w_uv: Optional[tensor] = None,
    chunked_data: Optional[
        SharedBlockChunkedBiasData
    ] = None,
    chunk_size: int = 0,
) -> tuple[tensor, tensor, tensor]

Partial attention where all shared blocks are compared with whole query.

Supports two modes:

1. Full bias mode (default): bias tensor is provided, process all blocks at once.
2. Chunked mode: chunk_size > 0 and chunked_data is provided, process blocks in chunks.
   - If bias is provided, slice from it.
   - If bias is None but chunked_data is provided, generate bias per chunk from dense block_usages.

Parameters:

Name Type Description Default
query tensor

Query tensor [tokens, num_heads, head_dim]

required
blocks tensor

Shared block indices [num_shared_blocks]

required
bias Optional[tensor]

Pre-computed bias tensor [query_len, num_blocks, block_size]. Can be None for chunked generation.

required
fmin tensor

Minimum float value for softmax stability

required
inputL_hpu_tensors Dict[tuple, Tensor]

Cache for softmax input tensors

required
inputM_hpu_tensors Dict[tuple, Tensor]

Cache for softmax input tensors

required
cache_utils CacheUtils

Cache utilities for fetching KV

required
dtype dtype

Output dtype for bias generation

required
w_uv Optional[tensor]

Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]

None
chunked_data Optional[SharedBlockChunkedBiasData]

Metadata for chunked processing (contains dense block_usages)

None
chunk_size int

Number of blocks per chunk (0 = full mode, >0 = chunked mode)

0

Returns:

Type Description
tuple[tensor, tensor, tensor]

Tuple of (unnormalized_weighted_V, local_max, local_sum)

Source code in vllm_gaudi/extension/unified.py
def partial_attn_shared(query: torch.tensor,
                        blocks: torch.tensor,
                        bias: Optional[torch.tensor],
                        fmin: torch.tensor,
                        inputL_hpu_tensors: Dict[tuple, torch.Tensor],
                        inputM_hpu_tensors: Dict[tuple, torch.Tensor],
                        cache_utils: CacheUtils,
                        dtype: torch.dtype,
                        w_uv: Optional[torch.tensor] = None,
                        chunked_data: Optional[SharedBlockChunkedBiasData] = None,
                        chunk_size: int = 0) -> tuple[torch.tensor, torch.tensor, torch.tensor]:
    """Partial attention of the whole query against every shared block.

    Dispatches to one of two implementations:

    1. Chunked mode, used when ``chunk_size > 0`` AND ``chunked_data`` is given.
       Blocks are processed in chunks. ``bias`` may be None here; the chunked
       helper then builds the bias per chunk from ``chunked_data.block_usages``.
       Returns (None, None, None) only when ``blocks`` is None.
    2. Full-bias mode, used otherwise. All blocks are processed at once with the
       precomputed ``bias``. Returns (None, None, None) when ``bias`` is None.

    Args:
        query: Query tensor [tokens, num_heads, head_dim]
        blocks: Shared block indices [num_shared_blocks]
        bias: Pre-computed bias tensor [query_len, num_blocks, block_size]. Can be None for chunked generation.
        fmin: Minimum float value for softmax stability
        inputL_hpu_tensors: Cache for softmax input tensors
        inputM_hpu_tensors: Cache for softmax input tensors
        cache_utils: Cache utilities for fetching KV
        dtype: Output dtype for bias generation (only forwarded in chunked mode)
        w_uv: Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]
        chunked_data: Metadata for chunked processing (contains dense block_usages)
        chunk_size: Number of blocks per chunk (0 = full mode, >0 = chunked mode when chunked_data is set)

    Returns:
        Tuple of (unnormalized_weighted_V, local_max, local_sum)
    """
    chunked = chunk_size > 0 and chunked_data is not None

    if chunked:
        # bias may legitimately be None here, so only a missing block list short-circuits.
        if blocks is None:
            return (None, None, None)
        return _partial_attn_shared_chunked(query, blocks, bias, chunked_data, chunk_size, fmin, inputL_hpu_tensors,
                                            inputM_hpu_tensors, cache_utils, dtype, w_uv)

    # Full-bias mode: without a bias there is nothing to attend to.
    if bias is None:
        return (None, None, None)
    return _partial_attn_shared_full(query, blocks, bias, fmin, inputL_hpu_tensors, inputM_hpu_tensors, cache_utils,
                                     w_uv)

partial_attn_unique

partial_attn_unique(
    query: tensor,
    blocks: tensor,
    block_mapping: tensor,
    bias: Optional[tensor],
    fmin: tensor,
    cache_utils: CacheUtils,
    w_uv: Optional[tensor] = None,
) -> tuple[tensor, tensor, tensor]

Partial attention where all blocks are used by max one query

Parameters:

Name Type Description Default
w_uv Optional[tensor]

Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]. If provided, assumes MLA mode where query/key/value are in latent space.

None
Source code in vllm_gaudi/extension/unified.py
def partial_attn_unique(query: torch.tensor,
                        blocks: torch.tensor,
                        block_mapping: torch.tensor,
                        bias: Optional[torch.tensor],
                        fmin: torch.tensor,
                        cache_utils: CacheUtils,
                        w_uv: Optional[torch.tensor] = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]:
    """Partial attention over blocks that each belong to at most one query.

    Every block is scored only against the query row selected by ``block_mapping``.
    Per-block softmax statistics are then rescaled to the per-query maximum and
    summed into per-query results.

    Args:
        query: Scaled query; dim 0 indexes query rows and dim 1 indexes heads.
        blocks: Unique block ids passed to ``cache_utils.fetch_unique``.
        block_mapping: For each block, the index of the query row it belongs to.
        bias: Additive per-block bias. If None, nothing is computed.
        fmin: Lower bound applied to each block's score maximum.
        cache_utils: Cache accessor used to fetch the blocks.
        w_uv: Optional MLA projection matrix [num_heads, latent_dim, v_head_dim].
              If provided, assumes MLA mode where query/key/value are in latent space.

    Returns:
        (unnormalized weighted V, max, sum) per query row, or (None, None, None)
        when ``bias`` is None.
    """
    if bias is None:
        return (None, None, None)

    num_queries = query.size(0)
    mla_mode = w_uv is not None

    if not mla_mode:
        # Standard attention: separate K/V caches with a KV-head dimension.
        group_count = cache_utils.kv_heads
        query = query.index_select(0, block_mapping).unflatten(1, (group_count, -1)).unsqueeze(-2)
        key, value = cache_utils.fetch_unique(blocks)
    else:
        # MLA: a single latent cache serves as both K and V for every head.
        head_count = query.size(1)
        latent = cache_utils.fetch_unique(blocks).unflatten(0, (-1, cache_utils.block_size))
        query = query.index_select(0, block_mapping).unflatten(1, (1, head_count)).unsqueeze(-2)
        key = latent.unsqueeze(1).unsqueeze(1).expand(-1, 1, head_count, -1, -1)
        value = latent.unsqueeze(1).unsqueeze(1).expand(-1, 1, head_count, -1, -1)

    one_hot_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=num_queries).to(query.dtype)

    scores = torch.matmul(query, key.transpose(-1, -2)) + bias.unsqueeze(1).unsqueeze(1).unsqueeze(1)
    local_max = torch.maximum(scores.amax(-1), fmin)
    probs = torch.exp(scores - local_max.unsqueeze(-1))
    local_sum = probs.sum(-1)
    out = torch.matmul(probs, value)

    if mla_mode:
        # Keep only the latent channels, then project them to full V per head.
        latent_dim = w_uv.size(1)
        per_head = out[..., :latent_dim].squeeze(1).squeeze(-2).transpose(0, 1)  # [num_heads, num_blocks, latent_dim]
        out = torch.bmm(per_head, w_uv)  # [num_heads, num_blocks, v_head_dim]
        out = out.transpose(0, 1).unsqueeze(1).unsqueeze(-2)  # [num_blocks, 1, num_heads, 1, v_head_dim]

    # Rescale each block to its query's overall max, then sum blocks per query.
    query_max = reduce_max(local_max, num_queries, block_mapping)
    rescale = torch.exp(local_max - query_max.index_select(0, block_mapping))
    query_sum = block2batch(local_sum * rescale, one_hot_mapping)
    out = block2batch(out * rescale.unsqueeze(-1), one_hot_mapping)
    return (out.flatten(1, 3), query_max.flatten(1, 3), query_sum.flatten(1, 3))

reduce_max

reduce_max(
    local_max: tensor, batch_size: int, mapping: tensor
)

Reduce local block maxima to per-group maximum

Source code in vllm_gaudi/extension/unified.py
def reduce_max(local_max: torch.tensor, batch_size: int, mapping: torch.tensor):
    """Reduce per-block maxima to a per-group maximum.

    Row ``i`` of ``local_max`` belongs to group ``mapping[i]``. The result has
    shape ``[batch_size, *local_max.shape[1:]]`` and holds the elementwise
    maximum over all rows mapped to each group. Groups that no row maps to
    stay at ``-inf``.

    Args:
        local_max: Per-block values; dim 0 is the block dim. Any number of
            trailing dims is allowed, including none (a 1-D input).
        batch_size: Number of output groups.
        mapping: Integer index tensor of length ``local_max.size(0)``, used as
            the index for ``index_reduce_`` along dim 0.

    Returns:
        Tensor of per-group maxima with the same dtype and device as ``local_max``.
    """
    shape_suffix = local_max.shape[1:]
    # index_reduce_ works on a 2-D view: collapse trailing dims. A 1-D input gets
    # a singleton trailing dim, because flatten(1, -1) is invalid for 1-D tensors.
    flat_max = local_max.flatten(1, -1) if local_max.dim() > 1 else local_max.unsqueeze(-1)
    group_max = torch.full([batch_size, flat_max.size(1)],
                           -math.inf,
                           dtype=flat_max.dtype,
                           device=flat_max.device)
    group_max = group_max.index_reduce_(0, mapping, flat_max, 'amax')
    return group_max.reshape(batch_size, *shape_suffix)

unified_attn

unified_attn(
    query: tensor,
    key: tensor,
    value: tensor,
    key_cache: tensor,
    value_cache: tensor,
    scale: float,
    metadata: HPUUnifiedAttentionMetadata,
) -> tensor

Main entry point for unified attention

Source code in vllm_gaudi/extension/unified.py
def unified_attn(query: torch.tensor, key: torch.tensor, value: torch.tensor, key_cache: torch.tensor,
                 value_cache: torch.tensor, scale: float, metadata: HPUUnifiedAttentionMetadata) -> torch.tensor:
    """Main entry point for unified attention

    The query is scaled once and attended against three sources, each handled
    by its own partial-attention function:

      1. causal: the new tokens' own ``key``/``value`` (``metadata.causal_bias``),
      2. shared: shared KV-cache blocks (``metadata.shared_blocks``),
      3. unique: unique KV-cache blocks (``metadata.unique_blocks``).

    Each partial returns an unnormalized (attn, max, sum) triple, or
    (None, None, None) when it has nothing to contribute. How the triples are
    combined depends on ``metadata.online_merge``:

      * True: each triple is folded into a running accumulator with
        ``online_merge_step`` right after it is computed. This function then
        drops its own reference to the triple, so the partial tensors are not
        kept alive here while later stages run.
      * False: all three triples are kept and combined at the end by ``merge``.

    When ``metadata.split_graphs`` is set, ``htorch.core.mark_step()`` is issued
    before the first stage and after every stage.

    Args:
        query: Query tensor for the new tokens.
        key: Key tensor for the new tokens (causal part).
        value: Value tensor for the new tokens (causal part).
        key_cache: Cached keys, accessed through ``CacheUtils``.
        value_cache: Cached values, accessed through ``CacheUtils``.
        scale: Factor applied to the query before any score is computed.
        metadata: Unified attention metadata (biases, block lists, merge options, fmin/feps).

    Returns:
        The normalized attention output, or ``query`` itself (unscaled) when none
        of the three partials produced a result.
    """

    scaled_query = query * scale
    cache_utils = CacheUtils(key_cache, value_cache, metadata.block_size)

    use_online_merge = metadata.online_merge
    split_graphs = metadata.split_graphs

    # Running accumulator for online merge (unused in offline mode).
    acc_attn, acc_max, acc_sum = None, None, None

    if split_graphs:
        htorch.core.mark_step()

    # 1. Causal attention
    causal = partial_attn_causal(query=scaled_query,
                                 key=key,
                                 value=value,
                                 bias=metadata.causal_bias,
                                 slice_size=metadata.causal_width,
                                 fmin=metadata.fmin,
                                 inputL_hpu_tensors=metadata.inputL_hpu_tensors,
                                 inputM_hpu_tensors=metadata.inputM_hpu_tensors,
                                 w_uv=None)
    if use_online_merge:
        acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *causal)
        # Its contribution now lives in the accumulator; do not pin the partial here.
        causal = None

    if split_graphs:
        htorch.core.mark_step()

    # 2. Shared attention
    shared = partial_attn_shared(query=scaled_query,
                                 blocks=metadata.shared_blocks,
                                 bias=metadata.shared_bias,
                                 fmin=metadata.fmin,
                                 inputL_hpu_tensors=metadata.inputL_hpu_tensors,
                                 inputM_hpu_tensors=metadata.inputM_hpu_tensors,
                                 cache_utils=cache_utils,
                                 dtype=query.dtype,
                                 w_uv=None,
                                 chunked_data=metadata.shared_bias_chunked,
                                 chunk_size=metadata.shared_chunk_size)
    if use_online_merge:
        acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *shared)
        shared = None

    if split_graphs:
        htorch.core.mark_step()

    # 3. Unique attention
    unique = partial_attn_unique(query=scaled_query,
                                 blocks=metadata.unique_blocks,
                                 block_mapping=metadata.unique_block_mapping,
                                 bias=metadata.unique_bias,
                                 fmin=metadata.fmin,
                                 cache_utils=cache_utils,
                                 w_uv=None)
    if use_online_merge:
        acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique)
        unique = None

    if split_graphs:
        htorch.core.mark_step()

    # Final normalization
    if use_online_merge:
        if acc_attn is None:
            return query
        acc_sum = torch.maximum(acc_sum, metadata.feps)
        return acc_attn / acc_sum.unsqueeze(-1)

    attn = merge(causal, shared, unique, feps=metadata.feps)
    if attn is None:
        return query
    return attn

unified_mla

unified_mla(
    query: Optional[tensor],
    key: Optional[tensor],
    value: Optional[tensor],
    latent_cache: tensor,
    scale: float,
    metadata: HPUUnifiedAttentionMetadata,
    w_uv: tensor,
    query_latent: Optional[tensor] = None,
) -> tensor

Main entry point for Unified MLA

Parameters:

Name Type Description Default
query Optional[tensor]

Query tensor for causal path (already uncompressed) [tokens, num_heads, qk_head_dim]. None if only cached attention is needed.

required
key Optional[tensor]

Key tensor for causal part [tokens, num_heads, qk_head_dim]. None for cached-only.

required
value Optional[tensor]

Value tensor for causal part in latent space [tokens, num_heads, latent_dim]. None for cached-only.

required
latent_cache tensor

Cached latent KV [num_blocks * block_size, latent_dim + rope_dim]

required
scale float

Attention scale factor

required
metadata HPUUnifiedAttentionMetadata

Unified attention metadata

required
w_uv tensor

Projection matrix from latent to full V [num_heads, latent_dim, v_head_dim]

required
query_latent Optional[tensor]

Query tensor for cached path (in latent space) [tokens, num_heads, latent_dim + rope_dim]. None if only causal attention is needed.

None
metadata.online_merge

Not a function parameter; read from metadata. If True, use online (incremental) merge algorithm. Merges after each partial attention to avoid large intermediate buffers. If False, use offline (single-pass) merge algorithm.

n/a

Returns:

Type Description
tensor

Attention output [tokens, num_heads * v_head_dim]

Note
  • For causal-only: pass query/key/value, set query_latent=None
  • For cached-only: pass query_latent, set query/key/value=None
  • For mixed batches: pass both query and query_latent
Source code in vllm_gaudi/extension/unified.py
def unified_mla(query: Optional[torch.tensor],
                key: Optional[torch.tensor],
                value: Optional[torch.tensor],
                latent_cache: torch.tensor,
                scale: float,
                metadata: HPUUnifiedAttentionMetadata,
                w_uv: torch.tensor,
                query_latent: Optional[torch.tensor] = None) -> torch.tensor:
    """Main entry point for Unified MLA

    Args:
        query: Query tensor for causal path (already uncompressed) [tokens, num_heads, qk_head_dim].
               None if only cached attention is needed.
        key: Key tensor for causal part [tokens, num_heads, qk_head_dim]. None for cached-only.
        value: Value tensor for causal part in latent space [tokens, num_heads, latent_dim]. None for cached-only.
        latent_cache: Cached latent KV [num_blocks * block_size, latent_dim + rope_dim]
        scale: Attention scale factor
        metadata: Unified attention metadata
        w_uv: Projection matrix from latent to full V [num_heads, latent_dim, v_head_dim]
        query_latent: Query tensor for cached path (in latent space) [tokens, num_heads, latent_dim + rope_dim].
                      None if only causal attention is needed.

    Merge strategy:
        Selected by ``metadata.online_merge``; it is not a function argument.
        If True, each partial attention is folded into a running accumulator
        with ``online_merge_step`` as soon as it is computed. This function then
        drops its own reference to that partial, so it is not kept alive here
        while later stages run. If False, the causal/shared/unique results are
        kept and combined in a single pass by ``merge``.
        When ``metadata.split_graphs`` is set, ``htorch.core.mark_step()`` is
        issued before the first stage and after every stage.

    Returns:
        Attention output [tokens, num_heads * v_head_dim]

    Note:
        - For causal-only: pass query/key/value, set query_latent=None
        - For cached-only: pass query_latent, set query/key/value=None
        - For mixed batches: pass both query and query_latent
    """
    assert query is not None or query_latent is not None, \
        "At least one of query or query_latent must be provided"

    # Use appropriate query for each path
    scaled_query_causal = query * scale if query is not None else None
    scaled_query_latent = query_latent * scale if query_latent is not None else None

    # MLA: latent cache has no head dimension, value_cache is None (stored in same cache)
    cache_utils = CacheUtils(latent_cache, value_cache=None, block_size=metadata.block_size, is_mla=True)
    use_online_merge = metadata.online_merge
    split_graphs = metadata.split_graphs

    # Running accumulator for online merge (unused in offline mode).
    acc_attn, acc_max, acc_sum = None, None, None
    no_result = (None, None, None)

    if split_graphs:
        htorch.core.mark_step()

    # Causal: compute-friendly path (expand K/V from latent)
    # key and value already expanded by caller
    # w_uv projection applied by unified function
    causal = no_result
    if scaled_query_causal is not None:
        causal = partial_attn_causal(query=scaled_query_causal,
                                     key=key,
                                     value=value,
                                     bias=metadata.causal_bias,
                                     slice_size=metadata.causal_width,
                                     fmin=metadata.fmin,
                                     inputL_hpu_tensors=metadata.inputL_hpu_tensors,
                                     inputM_hpu_tensors=metadata.inputM_hpu_tensors,
                                     w_uv=w_uv)
    if use_online_merge:
        acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *causal)
        # Its contribution now lives in the accumulator; do not pin the partial here.
        causal = None

    if split_graphs:
        htorch.core.mark_step()

    # Shared/Unique: memory-friendly path (Q in latent space, fetch cached latent KV)
    # query_latent is already transformed to latent space by caller
    # For these paths, we need to expand K/V from cached latent vectors

    # Single call handles both full and chunked modes
    shared = no_result
    if scaled_query_latent is not None:
        shared = partial_attn_shared(query=scaled_query_latent,
                                     blocks=metadata.shared_blocks,
                                     bias=metadata.shared_bias,
                                     fmin=metadata.fmin,
                                     inputL_hpu_tensors=metadata.inputL_hpu_tensors,
                                     inputM_hpu_tensors=metadata.inputM_hpu_tensors,
                                     cache_utils=cache_utils,
                                     dtype=scaled_query_latent.dtype,
                                     w_uv=w_uv,
                                     chunked_data=metadata.shared_bias_chunked,
                                     chunk_size=metadata.shared_chunk_size)
    if use_online_merge:
        acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *shared)
        shared = None

    if split_graphs:
        htorch.core.mark_step()

    unique = no_result
    if scaled_query_latent is not None:
        unique = partial_attn_unique(query=scaled_query_latent,
                                     blocks=metadata.unique_blocks,
                                     block_mapping=metadata.unique_block_mapping,
                                     bias=metadata.unique_bias,
                                     fmin=metadata.fmin,
                                     cache_utils=cache_utils,
                                     w_uv=w_uv)
    if use_online_merge:
        acc_attn, acc_max, acc_sum = online_merge_step(acc_attn, acc_max, acc_sum, *unique)
        unique = None

    if split_graphs:
        htorch.core.mark_step()

    if use_online_merge:
        if acc_attn is None:
            attn = None
        else:
            acc_sum = torch.maximum(acc_sum, metadata.feps)
            attn = acc_attn / acc_sum.unsqueeze(-1)
    else:
        attn = merge(causal, shared, unique, feps=metadata.feps)

    if attn is None:
        # No attention computed, return original query
        # Use whichever query was provided
        # FIXME(kzawora): I'm not quite sure if that's correct, needs verification
        fallback = query if query is not None else query_latent
        return fallback.flatten(-2, -1)  # [tokens, num_heads * head_dim]
    return attn