vllm_gaudi.extension.unified
CacheUtils
Helper utilities for kv-cache
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| is_mla | | If True, cache stores MLA latent vectors (no head dimension, single cache). If False, standard attention with per-head K/V caches. | False |
Source code in vllm_gaudi/extension/unified.py
__init__
_fetch_all
Fetch both keys and values using the selected fetch function
_fetch_single_shared
Fetch selected shared blocks from given cache
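A minimal sketch of what fetching blocks from a cache can look like, written with plain Python lists instead of HPU tensors. It assumes a flat cache of `num_blocks * block_size` slots, the layout documented for `latent_cache` in `unified_mla`; whether `CacheUtils` uses the same layout internally is an assumption, and `fetch_blocks` is a hypothetical name, not part of the module.

```python
from typing import List, Sequence


def fetch_blocks(
    cache: Sequence[Sequence[float]],
    block_ids: Sequence[int],
    block_size: int,
) -> List[List[Sequence[float]]]:
    """Gather whole blocks from a flat slot-major cache.

    cache:     [num_blocks * block_size][dim] (assumed layout)
    block_ids: indices of the blocks to fetch
    returns:   [len(block_ids)][block_size][dim]
    """
    return [
        list(cache[block * block_size:(block + 1) * block_size])
        for block in block_ids
    ]


# Four blocks of two slots each, one feature per slot.
cache = [[float(i)] for i in range(8)]
selected = fetch_blocks(cache, block_ids=[3, 0], block_size=2)
```

On tensors the same gather is usually a reshape to `[num_blocks, block_size, ...]` followed by an index select along the block dimension.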
_fetch_single_unique
Fetch selected unique blocks from given cache
fetch_shared
HPUUnifiedAttentionMetadata
dataclass
__init__
__init__(
block_size: int,
slot_mapping: tensor,
causal_bias: Optional[tensor],
causal_width: int,
shared_blocks: Optional[tensor],
shared_bias: Optional[tensor],
unique_blocks: Optional[tensor] | Optional[int],
unique_block_mapping: Optional[tensor],
unique_bias: Optional[tensor],
fmin: tensor,
feps: tensor,
inputL_hpu_tensors: Optional[Dict[tuple, Tensor]],
inputM_hpu_tensors: Optional[Dict[tuple, Tensor]],
) -> None
num_blocks
block2batch
convert_cl_aligned_tensor
convert_cl_aligned_tensor(
input_hpu, reference_size
) -> tensor
Convert a CL-aligned tensor to the reference size
create_softmax_fa2_input_tensors
create_softmax_fa2_input_tensors(
attn: tensor,
fmin: tensor,
inputL_hpu_tensors: Dict[tuple, Tensor],
inputM_hpu_tensors: Dict[tuple, Tensor],
) -> tuple[tensor, tensor]
Create dummy input tensors for the softmax_fa2 operation.
get_last_dim_size
get_vecsize_packsize
Get vecsize and packsize for given dtype
merge
Merge partial attention results into the final attention output
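The sketch below illustrates one common way to merge partial attention results: each partial carries an unnormalized output, its local score maximum, and its local sum of exponentials, and the merge rescales all of them to a shared maximum before normalizing. This three-part layout is an assumption suggested by the three-tensor return type of the `partial_attn_*` functions. The names `partial_attention` and `merge_partials` are hypothetical, and plain Python floats stand in for tensors.

```python
import math
from typing import List, Sequence, Tuple

# (unnormalized output, local max score, local sum of exp(score - local max))
Partial = Tuple[List[float], float, float]


def partial_attention(
    scores: Sequence[float], values: Sequence[Sequence[float]]
) -> Partial:
    """Softmax-weighted sum over one slice of keys, left unnormalized."""
    local_max = max(scores)
    weights = [math.exp(s - local_max) for s in scores]
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return out, local_max, sum(weights)


def merge_partials(partials: Sequence[Partial]) -> List[float]:
    """Rescale every partial to the global max, then normalize once."""
    global_max = max(local_max for _, local_max, _ in partials)
    dim = len(partials[0][0])
    numer = [0.0] * dim
    denom = 0.0
    for out, local_max, local_sum in partials:
        rescale = math.exp(local_max - global_max)
        denom += rescale * local_sum
        for d in range(dim):
            numer[d] += rescale * out[d]
    return [n / denom for n in numer]


scores = [0.5, 2.0, -1.0, 3.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, 3.0]]

# Attention over all keys at once versus two slices merged afterwards.
whole = merge_partials([partial_attention(scores, values)])
split = merge_partials([
    partial_attention(scores[:2], values[:2]),
    partial_attention(scores[2:], values[2:]),
])
```

Both paths give the same result up to floating-point error, which is what lets the causal, shared, and unique parts be computed independently and combined at the end.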
optional
Wrap an operation to support handling None values
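As an illustration of the idea, here is a minimal sketch of such a wrapper. It assumes that only the first argument is checked for `None` and that `None` is returned in that case; the actual behavior in `unified.py` may differ. `optional_sketch` and `scale` are hypothetical names.

```python
import functools
from typing import Any, Callable, Optional


def optional_sketch(op: Callable[..., Any]) -> Callable[..., Optional[Any]]:
    """Return a version of `op` that passes None through untouched."""

    @functools.wraps(op)
    def wrapper(first: Any, *args: Any, **kwargs: Any) -> Optional[Any]:
        # Assumption: a None first argument short-circuits the call.
        if first is None:
            return None
        return op(first, *args, **kwargs)

    return wrapper


def scale(values, factor):
    return [v * factor for v in values]


safe_scale = optional_sketch(scale)
```

Wrapping operations this way lets a pipeline treat a missing causal, shared, or unique part uniformly without repeating `None` checks at every call site.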
partial_attn_causal
partial_attn_causal(
query: tensor,
key: tensor,
value: tensor,
bias: Optional[tensor],
slice_size: int,
fmin: tensor,
inputL_hpu_tensors: Dict[tuple, Tensor],
inputM_hpu_tensors: Dict[tuple, Tensor],
w_uv: Optional[tensor] = None,
) -> tuple[tensor, tensor, tensor]
Partial attention where query, key, and value are assumed to be causal between slices
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| w_uv | Optional[tensor] | Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]. If provided, value is assumed to be in latent space and will be projected. | None |
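To make the `w_uv` shape concrete, the following sketch projects a latent-space attention result to per-head value space using nested Python lists. Flattening the heads into one row per token is an assumption borrowed from the return shape documented for `unified_mla`; `project_latent` is a hypothetical helper, not a function of this module.

```python
from typing import List, Sequence


def project_latent(
    attn_latent: Sequence[Sequence[Sequence[float]]],
    w_uv: Sequence[Sequence[Sequence[float]]],
) -> List[List[float]]:
    """Apply a per-head projection from latent_dim to v_head_dim.

    attn_latent: [tokens][num_heads][latent_dim]
    w_uv:        [num_heads][latent_dim][v_head_dim]
    returns:     [tokens][num_heads * v_head_dim]
    """
    projected = []
    for token in attn_latent:
        row: List[float] = []
        for head, latent in enumerate(token):
            weights = w_uv[head]
            v_head_dim = len(weights[0])
            for j in range(v_head_dim):
                row.append(sum(latent[k] * weights[k][j] for k in range(len(latent))))
        projected.append(row)
    return projected


# One token, two heads, latent_dim=2, v_head_dim=1.
attn_latent = [[[1.0, 2.0], [3.0, 4.0]]]
w_uv = [[[1.0], [1.0]], [[0.5], [0.0]]]
output = project_latent(attn_latent, w_uv)
```

With tensors, the same contraction is typically written as a single batched matrix multiply or einsum over the head dimension.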
partial_attn_shared
partial_attn_shared(
query: tensor,
blocks: tensor,
bias: Optional[tensor],
fmin: tensor,
inputL_hpu_tensors: Dict[tuple, Tensor],
inputM_hpu_tensors: Dict[tuple, Tensor],
cache_utils: CacheUtils,
w_uv: Optional[tensor] = None,
) -> tuple[tensor, tensor, tensor]
Partial attention where every shared block is attended to by the whole query
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| w_uv | Optional[tensor] | Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]. If provided, assumes MLA mode where query/key/value are in latent space. | None |
partial_attn_unique
partial_attn_unique(
query: tensor,
blocks: tensor,
block_mapping: tensor,
bias: Optional[tensor],
fmin: tensor,
cache_utils: CacheUtils,
w_uv: Optional[tensor] = None,
) -> tuple[tensor, tensor, tensor]
Partial attention where each block is used by at most one query
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| w_uv | Optional[tensor] | Optional MLA projection matrix [num_heads, latent_dim, v_head_dim]. If provided, assumes MLA mode where query/key/value are in latent space. | None |
reduce_max
Reduce local block maxima to per-group maximum
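Going by the function name, this is read here as a max-reduction: each block contributes one local value, and blocks belonging to the same group are reduced to a single value. The sketch below shows that reduction on plain lists. The `block_to_group` mapping and the name `reduce_max_sketch` are hypothetical; the real function's arguments are not documented on this page.

```python
from typing import List, Sequence


def reduce_max_sketch(
    local_max: Sequence[float],
    block_to_group: Sequence[int],
    num_groups: int,
) -> List[float]:
    """Reduce one value per block to one maximum per group."""
    # Groups that own no block keep -inf, the identity element for max.
    group_max = [float("-inf")] * num_groups
    for value, group in zip(local_max, block_to_group):
        group_max[group] = max(group_max[group], value)
    return group_max


# Blocks 0 and 1 belong to group 0, blocks 2 and 3 to group 1.
per_group = reduce_max_sketch([0.1, 2.0, -3.0, 1.5], [0, 0, 1, 1], num_groups=2)
```

A per-group maximum of this kind is what a numerically stable softmax needs before exponentiating scores that are spread across several blocks.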
unified_attn
unified_attn(
query: tensor,
key: tensor,
value: tensor,
key_cache: tensor,
value_cache: tensor,
scale: float,
metadata: HPUUnifiedAttentionMetadata,
) -> tensor
Main entry point for unified attention
unified_mla
unified_mla(
query: Optional[tensor],
key: Optional[tensor],
value: Optional[tensor],
latent_cache: tensor,
scale: float,
metadata: HPUUnifiedAttentionMetadata,
w_uv: tensor,
query_latent: Optional[tensor] = None,
) -> tensor
Main entry point for Unified MLA
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query | Optional[tensor] | Query tensor for causal path (already uncompressed) [tokens, num_heads, qk_head_dim]. None if only cached attention is needed. | required |
| key | Optional[tensor] | Key tensor for causal part [tokens, num_heads, qk_head_dim]. None for cached-only. | required |
| value | Optional[tensor] | Value tensor for causal part in latent space [tokens, num_heads, latent_dim]. None for cached-only. | required |
| latent_cache | tensor | Cached latent KV [num_blocks * block_size, latent_dim + rope_dim] | required |
| scale | float | Attention scale factor | required |
| metadata | HPUUnifiedAttentionMetadata | Unified attention metadata | required |
| w_uv | tensor | Projection matrix from latent to full V [num_heads, latent_dim, v_head_dim] | required |
| query_latent | Optional[tensor] | Query tensor for cached path (in latent space) [tokens, num_heads, latent_dim + rope_dim]. None if only causal attention is needed. | None |
Returns:
| Type | Description |
|---|---|
| tensor | Attention output [tokens, num_heads * v_head_dim] |
Note
- For causal-only: pass query/key/value, set query_latent=None
- For cached-only: pass query_latent, set query/key/value=None
- For mixed batches: pass both query and query_latent
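The three cases in the note can be summarized by which of `query` and `query_latent` is present. The helper below is a hypothetical illustration of that decision, not part of the module; in particular, raising an error when both are `None` is an assumption, since the page does not say how `unified_mla` handles that input.

```python
from typing import Optional


def mla_call_mode(query: Optional[object], query_latent: Optional[object]) -> str:
    """Classify a unified_mla call by which query inputs are provided."""
    if query is not None and query_latent is not None:
        return "mixed"          # both causal and cached paths run
    if query is not None:
        return "causal-only"    # key/value are expected alongside query
    if query_latent is not None:
        return "cached-only"    # query/key/value are all None
    # Assumption: a call with neither query is treated as invalid here.
    raise ValueError("provide query, query_latent, or both")
```

Any non-`None` placeholder works for trying it out, for example `mla_call_mode("q", None)` for the causal-only case.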