class DeepseekV4XPUAttention(DeepseekV4Attention):
    """XPU sparse MLA attention layer for DeepSeek V4.

    Overrides the parts of ``DeepseekV4Attention`` that need XPU kernels:
    the fused q-norm/RoPE/KV insert into the SWA cache, the output
    projection (BF16 reference path shared with ROCm), and the sparse
    decode/prefill attention paths.
    """

    backend_cls = DeepseekV4XPUSparseBackend
    # NOTE(review): presumably selects the FlashMLA FP8 KV-cache layout in the
    # shared base implementation — confirm against DeepseekV4Attention.
    use_flashmla_fp8_layout = True

    def __init__(self, *args, **kwargs) -> None:
        # torch.cuda.Event() raises RuntimeError on XPU ("dummy base class").
        # The base class and DeepseekV4Indexer both create CUDA events in
        # __init__, so torch.cuda.Event is temporarily redirected to
        # torch.xpu.Event while the parent constructor runs. The `finally`
        # restores the original attribute even if construction raises.
        # NOTE(review): this swaps a process-global attribute; constructing
        # layers concurrently from several threads could observe the swap.
        _orig_event = torch.cuda.Event
        torch.cuda.Event = torch.xpu.Event  # type: ignore[misc]
        try:
            super().__init__(*args, **kwargs)
        finally:
            torch.cuda.Event = _orig_event  # type: ignore[misc]

    def _fused_qnorm_rope_kv_insert(self, q, kv, positions, attn_metadata):
        """Run the XPU fused q-norm / RoPE / FP8 KV-insert kernel.

        Calls ``xpu_qnorm_rope_kv_fp8_insert`` with ``q``, ``kv``, this
        layer's SWA KV cache, the SWA slot mapping, ``positions``, the RoPE
        cos/sin cache, ``self.eps`` and the SWA block size, then returns the
        same ``q`` object (presumably updated in place by the kernel —
        confirm against the kernel implementation).

        When ``attn_metadata`` is not a dict (profile run) no kernel is
        launched and ``q`` is returned as-is; unlike other backends, no head
        padding is needed on XPU.
        """
        if not isinstance(attn_metadata, dict):
            # Profile run: no-op, just return q (no padding needed on XPU).
            return q
        swa_metadata = cast(
            "DeepseekSparseSWAMetadata | None",
            attn_metadata.get(self.swa_cache_layer.prefix),
        )
        assert swa_metadata is not None
        from vllm.models.deepseek_v4.xpu.xpu_qnorm_rope_kv_fp8_insert import (
            xpu_qnorm_rope_kv_fp8_insert,
        )

        xpu_qnorm_rope_kv_fp8_insert(
            q,
            kv,
            self.swa_cache_layer.kv_cache,
            swa_metadata.slot_mapping,
            positions,
            self.rotary_emb.cos_sin_cache,
            self.eps,
            swa_metadata.block_size,
        )
        return q

    @classmethod
    def get_padded_num_q_heads(cls, num_heads: int) -> int:
        """Return ``num_heads`` unchanged: XPU applies no query-head padding."""
        return num_heads

    def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Apply the output projection to the attention output ``o``.

        XPU uses the BF16 reference ``wo_a`` path shared with ROCm
        (``rocm_inv_rope_einsum``), then feeds the result, flattened from
        dim 1, through ``wo_b``.
        """
        from vllm.models.deepseek_v4.amd.rocm import rocm_inv_rope_einsum

        z = rocm_inv_rope_einsum(
            self.rotary_emb,
            o,
            positions,
            self.rope_head_dim,
            self.n_local_groups,
            self.o_lora_rank,
            self.wo_a,
        )
        return self.wo_b(z.flatten(1))

    def _xpu_prefill_kv_rows(self, swa_only: bool) -> tuple[int, int]:
        """Return ``(N, M)`` row counts for one request's prefill KV slab.

        ``N`` rows are reserved for dequantized compressed KV
        (``ceil(max_model_len / compress_ratio)``, or 0 for SWA-only layers);
        SWA KV is gathered starting at row ``N``. ``M`` is the total row
        count, ``N + window_size + max_num_batched_tokens``.

        Used by both the warmup reservation in ``forward_mqa`` and the real
        workspace request in ``_forward_prefill`` so the two cannot drift.
        """
        if swa_only:
            n_rows = 0
        else:
            n_rows = (
                self.max_model_len + self.compress_ratio - 1
            ) // self.compress_ratio
        return n_rows, n_rows + self.window_size + self.max_num_batched_tokens

    def forward_mqa(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        positions: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """Sparse MLA attention over the SWA cache and, if present, the
        compressed cache. Results are written into ``output``.

        Tokens are laid out decodes first: ``q[:num_decode_tokens]`` goes to
        ``_forward_decode`` and the remainder to ``_forward_prefill``.
        ``output`` must match ``q`` in shape and dtype. On the warmup dummy
        run (no attention metadata) only the prefill workspace is reserved
        and ``output`` is zeroed.
        """
        assert output.shape == q.shape, (
            f"output buffer shape {output.shape} must match q shape {q.shape}"
        )
        assert output.dtype == q.dtype, (
            f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
        )

        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            # Warmup dummy run: reserve the prefill KV workspace (sized by the
            # same helper _forward_prefill uses) and skip the actual kernels.
            _, M = self._xpu_prefill_kv_rows(swa_only=self.compress_ratio <= 1)
            current_workspace_manager().get_simultaneous(
                ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
            )
            output.zero_()
            return

        assert isinstance(attn_metadata, dict)
        flashmla_metadata = cast(
            DeepseekV4FlashMLAMetadata | None, attn_metadata.get(self.prefix)
        )
        swa_metadata = cast(
            "DeepseekSparseSWAMetadata | None",
            attn_metadata.get(self.swa_cache_layer.prefix),
        )
        assert swa_metadata is not None

        swa_only = self.compress_ratio <= 1
        self_kv_cache = self.kv_cache if not swa_only else None
        swa_kv_cache = self.swa_cache_layer.kv_cache

        # Split prefill and decode: decode tokens occupy the leading
        # num_decode_tokens rows, prefill tokens the rest.
        num_decodes = swa_metadata.num_decodes
        num_prefills = swa_metadata.num_prefills
        num_decode_tokens = swa_metadata.num_decode_tokens

        if num_prefills > 0:
            self._forward_prefill(
                q=q[num_decode_tokens:],
                positions=positions[num_decode_tokens:],
                compressed_k_cache=self_kv_cache,
                swa_k_cache=swa_kv_cache,
                output=output[num_decode_tokens:],
                attn_metadata=flashmla_metadata,
                swa_metadata=swa_metadata,
            )
        if num_decodes > 0:
            self._forward_decode(
                q=q[:num_decode_tokens],
                kv_cache=self_kv_cache,
                swa_metadata=swa_metadata,
                attn_metadata=flashmla_metadata,
                swa_only=swa_only,
                output=output[:num_decode_tokens],
            )

    def _forward_decode(
        self,
        q: torch.Tensor,
        kv_cache: torch.Tensor | None,
        swa_metadata: "DeepseekSparseSWAMetadata",
        attn_metadata: DeepseekV4FlashMLAMetadata | None,
        swa_only: bool,
        output: torch.Tensor,
    ) -> None:
        """Run sparse FP8 decode attention via ``xpu_sparse_decode_fp8``.

        For compressed layers (``swa_only`` False), top-k candidate indices
        into the compressed cache come from either:

        - C4A (``compress_ratio == 4``): the per-layer Indexer buffer,
          converted by ``compute_global_topk_indices_and_lens`` using the
          decode requests' block table; or
        - C128A (any other ratio): indices/lengths pre-computed during the
          metadata build.

        SWA indices/lengths always come from ``swa_metadata``. The result is
        written into ``output``.
        """
        num_decodes = swa_metadata.num_decodes
        num_decode_tokens = swa_metadata.num_decode_tokens

        topk_indices = None
        topk_lens = None
        if not swa_only:
            assert attn_metadata is not None
            assert swa_metadata.is_valid_token is not None
            # Block size of the compressed cache, in compressed entries.
            block_size = attn_metadata.block_size // self.compress_ratio
            is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
            if self.compress_ratio == 4:
                # C4A: local indices differ per layer (filled by Indexer).
                assert self.topk_indices_buffer is not None
                global_indices, topk_lens = compute_global_topk_indices_and_lens(
                    self.topk_indices_buffer[:num_decode_tokens],
                    swa_metadata.token_to_req_indices,
                    attn_metadata.block_table[:num_decodes],
                    block_size,
                    is_valid,
                )
                topk_indices = global_indices.view(num_decode_tokens, 1, -1)
            else:
                # C128A: pre-computed during metadata build.
                topk_indices = attn_metadata.c128a_global_decode_topk_indices
                topk_lens = attn_metadata.c128a_decode_topk_lens

        swa_indices = swa_metadata.decode_swa_indices
        swa_lens = swa_metadata.decode_swa_lens
        assert swa_indices is not None and swa_lens is not None

        xpu_sparse_decode_fp8(
            q=q,
            kv_cache=kv_cache,
            swa_kv_cache=self.swa_cache_layer.kv_cache,
            swa_only=swa_only,
            topk_indices=topk_indices,
            topk_lens=topk_lens,
            swa_indices=swa_indices,
            swa_lens=swa_lens,
            attn_sink=self.attn_sink,
            softmax_scale=self.scale,
            head_dim=self.head_dim,
            nope_head_dim=self.nope_head_dim,
            rope_head_dim=self.rope_head_dim,
            out=output,
        )

    def _forward_prefill(
        self,
        q: torch.Tensor,
        positions: torch.Tensor,
        compressed_k_cache: torch.Tensor | None,
        swa_k_cache: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: DeepseekV4FlashMLAMetadata | None,
        swa_metadata: "DeepseekSparseSWAMetadata",
    ) -> None:
        """Run sparse BF16 attention for prefill tokens, chunk by chunk.

        Prefill requests are processed ``PREFILL_CHUNK_SIZE`` at a time.
        For each chunk:

        1. Each request's compressed KV (if any) is gathered via
           ``dequantize_and_gather_k_cache`` into a BF16 workspace slab of
           ``M`` rows at offset 0, and its SWA KV at offset ``N``.
        2. Top-k and SWA indices are combined into indices over that slab.
        3. ``triton_bf16_mla_sparse_interface`` computes attention for the
           chunk's query tokens.

        ``q`` and ``output`` are the prefill-only slices; the caller has
        already stripped the decode tokens. ``positions`` is not used here.
        """
        swa_only = attn_metadata is None
        # NOTE(review): forward_mqa/_forward_decode decide SWA-only from
        # ``self.compress_ratio <= 1``; here it is inferred from missing
        # FlashMLA metadata. The two presumably coincide — confirm.
        num_prefills = swa_metadata.num_prefills
        num_prefill_tokens = swa_metadata.num_prefill_tokens
        num_decodes = swa_metadata.num_decodes
        num_decode_tokens = swa_metadata.num_decode_tokens

        # Use pre-computed prefill metadata.
        seq_lens = swa_metadata.prefill_seq_lens
        gather_lens = swa_metadata.prefill_gather_lens
        assert seq_lens is not None
        assert gather_lens is not None

        # Derive prefill-local token offsets from the full query_start_loc_cpu.
        query_start_loc_cpu = swa_metadata.query_start_loc_cpu
        query_start_loc = swa_metadata.query_start_loc
        assert query_start_loc_cpu is not None
        assert query_start_loc is not None
        prefill_token_base = query_start_loc_cpu[num_decodes]

        if not swa_only:
            if self.compress_ratio == 4:
                # C4A: per-layer local indices written by the Indexer.
                assert self.topk_indices_buffer is not None
                topk_indices = self.topk_indices_buffer[num_decode_tokens:]
                topk_indices = topk_indices[:num_prefill_tokens]
            else:
                # C128A: pre-computed during metadata build.
                assert attn_metadata is not None
                topk_indices = attn_metadata.c128a_prefill_topk_indices
                assert topk_indices is not None, (
                    "C128A prefill requires c128a_prefill_topk_indices"
                )
            top_k = topk_indices.shape[-1]
        else:
            # SWA-only: no compressed candidates (top_k == 0), but a buffer
            # slice is still passed to combine_topk_swa_indices below.
            assert self.topk_indices_buffer is not None
            topk_indices = self.topk_indices_buffer[num_decode_tokens:]
            top_k = 0
        # N: compressed rows per request (SWA rows start at N); M: slab rows.
        N, M = self._xpu_prefill_kv_rows(swa_only)

        chunk_size_const = self.PREFILL_CHUNK_SIZE
        num_chunks = (num_prefills + chunk_size_const - 1) // chunk_size_const
        workspace_manager = current_workspace_manager()
        kv = workspace_manager.get_simultaneous(
            ((chunk_size_const, M, q.shape[-1]), torch.bfloat16),
        )[0]

        for chunk_idx in range(num_chunks):
            chunk_start = chunk_idx * chunk_size_const
            chunk_end = min(chunk_start + chunk_size_const, num_prefills)
            chunk_size = chunk_end - chunk_start

            if not swa_only:
                # Gather compressed KV into rows [0, N) of each request's slab.
                assert attn_metadata is not None
                block_table = attn_metadata.block_table[num_decodes:]
                dequantize_and_gather_k_cache(
                    kv[:chunk_size],
                    compressed_k_cache,
                    seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
                    gather_lens=None,
                    block_table=block_table[chunk_start:chunk_end],
                    block_size=attn_metadata.block_size // self.compress_ratio,
                    offset=0,
                )

            # Gather SWA KV starting at row N of each request's slab.
            swa_block_table = swa_metadata.block_table[num_decodes:]
            dequantize_and_gather_k_cache(
                kv[:chunk_size],
                swa_k_cache,
                seq_lens=seq_lens[chunk_start:chunk_end],
                gather_lens=gather_lens[chunk_start:chunk_end],
                block_table=swa_block_table[chunk_start:chunk_end],
                block_size=swa_metadata.block_size,
                offset=N,
            )

            # Combine the topk indices and SWA indices for gathered KV cache.
            # query_start/query_end are prefill-local token offsets.
            query_start = (
                query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base
            )
            query_end = (
                query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base
            )
            combined_indices, combined_lens = combine_topk_swa_indices(
                topk_indices[query_start:query_end],
                query_start_loc[
                    num_decodes + chunk_start : num_decodes + chunk_end + 1
                ],
                seq_lens[chunk_start:chunk_end],
                gather_lens[chunk_start:chunk_end],
                self.window_size,
                self.compress_ratio,
                top_k,
                M,
                N,
            )

            # Flatten the per-request slabs into a single KV sequence that the
            # combined indices address.
            kv_ws = kv[:chunk_size].reshape(-1, 1, q.shape[-1])
            out, _, _ = triton_bf16_mla_sparse_interface(
                q=q[query_start:query_end],
                kv=kv_ws,
                indices=combined_indices.unsqueeze(1),
                sm_scale=self.scale,
                d_v=q.shape[-1],
                block_dpe=0,
            )
            output[query_start:query_end] = out