# vllm_gaudi.ops.hpu_multihead_attn

class HpuMultiHeadAttention(MultiHeadAttention):
    """HPU (Gaudi) flavour of ``MultiHeadAttention``.

    Only ``forward`` is overridden. The attention product is computed by the
    fused SDPA kernel from ``vllm_gaudi.extension.kernels`` when the runtime
    config selects ``prompt_attn_impl == 'fsdpa_impl'``. Otherwise it falls
    back to ``torch.nn.functional.scaled_dot_product_attention``.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Run non-causal attention of ``query`` over ``key`` / ``value``.

        Accepted input layouts:
            (batch_size x seq_len x hidden_size) or
            (batch_size x seq_len x num_heads x head_size)

        Each input is split per head with ``.view``. The elements after the
        first two dims must therefore total ``num_heads * head_size`` for
        ``query`` and ``num_kv_heads * head_size`` for ``key`` / ``value``.

        Returns:
            Tensor reshaped to ``(batch_size, q_len, -1)``, i.e. the heads
            merged back into the last dimension.
        """
        batch_size, q_len = query.shape[:2]
        kv_len = key.shape[1]

        # Split the hidden dimension into (heads, head_size).
        q_heads = query.view(batch_size, q_len, self.num_heads,
                             self.head_size)
        k_heads = key.view(batch_size, kv_len, self.num_kv_heads,
                           self.head_size)
        v_heads = value.view(batch_size, kv_len, self.num_kv_heads,
                             self.head_size)

        # MQA / GQA: repeat each KV head along dim 2 so that the KV head
        # count matches the query head count.
        repeats = self.num_queries_per_kv
        if repeats > 1:
            k_heads = torch.repeat_interleave(k_heads, repeats, dim=2)
            v_heads = torch.repeat_interleave(v_heads, repeats, dim=2)

        # (batch, seq, heads, head_size) -> (batch, heads, seq, head_size)
        q_heads = q_heads.transpose(1, 2)
        k_heads = k_heads.transpose(1, 2)
        v_heads = v_heads.transpose(1, 2)

        # Imported at call time rather than at module import.
        from vllm_gaudi.extension.runtime import get_config

        use_fused = get_config().prompt_attn_impl == 'fsdpa_impl'
        if use_fused:
            from vllm_gaudi.extension.utils import ModuleFusedSDPA
            import vllm_gaudi.extension.kernels as kernels

            # NOTE(review): the kernel lookup and its wrapper are rebuilt on
            # every forward call; consider building them once if this shows
            # up as overhead.
            fused_sdpa = ModuleFusedSDPA(kernels.fsdpa())
            # Positional None is presumably the attention mask (no mask);
            # confirm against ModuleFusedSDPA's signature.
            attn = fused_sdpa(q_heads,
                              k_heads,
                              v_heads,
                              None,
                              dropout_p=0.0,
                              is_causal=False,
                              scale=self.scale,
                              softmax_mode="fast",
                              recompute_mode=True,
                              valid_sequence_lengths=None)
        else:
            attn = F.scaled_dot_product_attention(q_heads,
                                                  k_heads,
                                                  v_heads,
                                                  scale=self.scale)

        # Back to (batch, seq, heads, head_size), then merge the heads.
        return attn.transpose(1, 2).reshape(batch_size, q_len, -1)

# Monkey-patch: rebind the ``MultiHeadAttention`` attribute of ``layer`` to
# the HPU subclass. Only code that looks the class up through ``layer`` after
# this runs sees the replacement; names bound earlier via
# ``from ... import MultiHeadAttention`` keep the original class.
# NOTE(review): ``layer`` is imported outside this view — presumably vLLM's
# attention layer module; confirm.
setattr(layer, "MultiHeadAttention", HpuMultiHeadAttention)