Skip to content

vllm_gaudi.ops.hpu_mm_encoder_attention

HpuMMEncoderAttention

Bases: MMEncoderAttention

Source code in vllm_gaudi/ops/hpu_mm_encoder_attention.py
@MMEncoderAttention.register_oot()
class HpuMMEncoderAttention(MMEncoderAttention):
    """Gaudi (HPU) variant of ``MMEncoderAttention``.

    Registered through ``MMEncoderAttention.register_oot()``; overrides
    ``_forward_sdpa`` to run either the HPU fused SDPA kernel or
    ``F.scaled_dot_product_attention``, chosen by the runtime
    ``prompt_attn_impl`` setting.
    """

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Non-causal scaled-dot-product attention.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Args:
            query: Viewed as ``(bsz, q_len, self.num_heads, self.head_size)``.
            key: Viewed as ``(bsz, kv_len, self.num_kv_heads, self.head_size)``.
            value: Same layout as ``key``.
            cu_seqlens: Optional cumulative boundaries along the sequence
                dimension. When given, attention is computed independently
                for each ``[cu_seqlens[i], cu_seqlens[i + 1])`` span (the same
                span is sliced from query, key and value), so tokens never
                attend across spans. Expected to hold at least two entries.

        Returns:
            ``(bsz, cu_seqlens[-1] - cu_seqlens[0], num_heads, head_size)``
            when ``cu_seqlens`` is given, otherwise
            ``(bsz, q_len, num_heads * head_size)``. The layout is the same
            for both kernel choices.
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA: expand KV heads to match the query heads.
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        # Both kernels below take (batch, heads, seq, head_size).
        query, key, value = (x.transpose(1, 2) for x in (query, key, value))

        from vllm_gaudi.extension.runtime import get_config

        scale = self.scale
        # Select the kernel once; the chunking/reshaping below is shared so
        # cu_seqlens is honoured (and the output layout is identical)
        # whichever implementation is configured.
        if get_config().prompt_attn_impl == 'fsdpa_impl':

            from vllm_gaudi.extension.utils import ModuleFusedSDPA
            import vllm_gaudi.extension.kernels as kernels

            HPUFusedSDPA = kernels.fsdpa()
            fsdpa_op = ModuleFusedSDPA(HPUFusedSDPA)

            def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
                return fsdpa_op(q,
                                k,
                                v,
                                None,
                                dropout_p=0.0,
                                is_causal=False,
                                scale=scale,
                                softmax_mode="fast",
                                recompute_mode=True,
                                valid_sequence_lengths=None)
        else:

            def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
                return F.scaled_dot_product_attention(q, k, v, scale=scale)

        if cu_seqlens is None:
            out = attend(query, key, value)
            return out.transpose(1, 2).reshape(bsz, q_len, -1)

        # Packed variable-length input: attend within each span only.
        # Bounds are moved to host because they drive Python slicing.
        bounds = cu_seqlens.to("cpu").tolist()
        outputs = [
            attend(query[:, :, start:end, :], key[:, :, start:end, :], value[:, :, start:end, :])
            for start, end in zip(bounds[:-1], bounds[1:])
        ]
        out = torch.cat(outputs, dim=2)
        # Per-span output stays 4D: (bsz, seq, num_heads, head_size).
        return out.transpose(1, 2)

_forward_sdpa

_forward_sdpa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm_gaudi/ops/hpu_mm_encoder_attention.py
def _forward_sdpa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Non-causal scaled-dot-product attention.

    Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)

    Args:
        query: Viewed as ``(bsz, q_len, self.num_heads, self.head_size)``.
        key: Viewed as ``(bsz, kv_len, self.num_kv_heads, self.head_size)``.
        value: Same layout as ``key``.
        cu_seqlens: Optional cumulative boundaries along the sequence
            dimension. When given, attention is computed independently
            for each ``[cu_seqlens[i], cu_seqlens[i + 1])`` span (the same
            span is sliced from query, key and value), so tokens never
            attend across spans. Expected to hold at least two entries.

    Returns:
        ``(bsz, cu_seqlens[-1] - cu_seqlens[0], num_heads, head_size)``
        when ``cu_seqlens`` is given, otherwise
        ``(bsz, q_len, num_heads * head_size)``. The layout is the same
        for both kernel choices.
    """
    bsz, q_len = query.size()[:2]
    kv_len = key.size(1)

    query = query.view(bsz, q_len, self.num_heads, self.head_size)
    key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
    value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

    if (num_repeat := self.num_queries_per_kv) > 1:
        # Handle MQA and GQA: expand KV heads to match the query heads.
        key = torch.repeat_interleave(key, num_repeat, dim=2)
        value = torch.repeat_interleave(value, num_repeat, dim=2)

    # Both kernels below take (batch, heads, seq, head_size).
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))

    from vllm_gaudi.extension.runtime import get_config

    scale = self.scale
    # Select the kernel once; the chunking/reshaping below is shared so
    # cu_seqlens is honoured (and the output layout is identical)
    # whichever implementation is configured.
    if get_config().prompt_attn_impl == 'fsdpa_impl':

        from vllm_gaudi.extension.utils import ModuleFusedSDPA
        import vllm_gaudi.extension.kernels as kernels

        HPUFusedSDPA = kernels.fsdpa()
        fsdpa_op = ModuleFusedSDPA(HPUFusedSDPA)

        def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
            return fsdpa_op(q,
                            k,
                            v,
                            None,
                            dropout_p=0.0,
                            is_causal=False,
                            scale=scale,
                            softmax_mode="fast",
                            recompute_mode=True,
                            valid_sequence_lengths=None)
    else:

        def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
            return F.scaled_dot_product_attention(q, k, v, scale=scale)

    if cu_seqlens is None:
        out = attend(query, key, value)
        return out.transpose(1, 2).reshape(bsz, q_len, -1)

    # Packed variable-length input: attend within each span only.
    # Bounds are moved to host because they drive Python slicing.
    bounds = cu_seqlens.to("cpu").tolist()
    outputs = [
        attend(query[:, :, start:end, :], key[:, :, start:end, :], value[:, :, start:end, :])
        for start, end in zip(bounds[:-1], bounds[1:])
    ]
    out = torch.cat(outputs, dim=2)
    # Per-span output stays 4D: (bsz, seq, num_heads, head_size).
    return out.transpose(1, 2)