Skip to content

vllm_gaudi.extension.utils

B2BMatmul

Bases: Matmul

Specialized alias for batch2block and block2batch matmul operations.

This class remains functionally identical to Matmul but is used to semantically mark B2B-related matmuls. This enables the system to apply the fix that uses the B2B output measurements as the input measurements during calibration, avoiding corrupted scales from the KV‑cache.

Source code in vllm_gaudi/extension/utils.py
class B2BMatmul(Matmul):
    """Marker subclass of ``Matmul`` for batch2block / block2batch matmuls.

    Computationally this is exactly ``Matmul``. The distinct type exists only so
    that B2B-related matmuls can be recognized during calibration, where their
    output measurements are reused as input measurements to avoid scales that
    would otherwise be corrupted by the KV-cache.
    """

    def __init__(self):
        # No extra state: defer entirely to Matmul.
        super(B2BMatmul, self).__init__()

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    """Excerpt of ``B2BMatmul.__init__``: delegates construction to ``Matmul``."""
    super().__init__()

FP8Matmul

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class FP8Matmul(torch.nn.Module):
    """Matmul executed in FP8 (E4M3) on HPU.

    Both operands are cast to ``torch.float8_e4m3fn`` with their own scale.
    The product is computed by the HPU FP8 GEMM op and returned as bfloat16.
    The reciprocals of the scales are passed to the GEMM as
    ``A_scale_inv`` / ``B_scale_inv``.

    Args:
        scale_input: scale used when casting the left operand to FP8.
        scale_other: scale used when casting the right operand to FP8.
    """

    def __init__(self, scale_input=1.0, scale_other=1.0):
        super().__init__()
        self.scale_input, self.scale_other = scale_input, scale_other

    def quant_input(self, x, scale):
        """Cast ``x`` to FP8 E4M3 using ``scale``; return the first element of the op's result."""
        # NOTE(review): the two positional False flags are presumably the
        # stochastic-rounding and amax switches of cast_to_fp8_v2 — confirm.
        cast_result = torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)
        return cast_result[0]

    def matmul_fp8(self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None):
        """Call the HPU FP8 GEMM on ``x`` and ``other``.

        No transposes, no bias, no accumulation into an existing output.
        """
        gemm_args = dict(
            A=x,
            trans_A=False,
            B=other,
            trans_B=False,
            D=None,
            out_dtype=out_dtype,
            A_scale_inv=scale_input_inv,
            B_scale_inv=scale_other_inv,
            bias=None,
            accumulate=False,
        )
        return torch.ops.hpu.fp8_gemm_v2(**gemm_args)

    def forward(self, input, other, **kwargs):
        """Quantize both operands, multiply them in FP8 and return a bfloat16 tensor.

        Extra keyword arguments are accepted but ignored.
        """
        in_scale = self.scale_input
        other_scale = self.scale_other
        lhs_fp8 = self.quant_input(input, in_scale)
        rhs_fp8 = self.quant_input(other, other_scale)
        return self.matmul_fp8(
            lhs_fp8,
            rhs_fp8,
            out_dtype=torch.bfloat16,
            scale_input_inv=1.0 / in_scale,
            scale_other_inv=1.0 / other_scale,
        )

scale_input instance-attribute

scale_input = scale_input

scale_other instance-attribute

scale_other = scale_other

__init__

__init__(scale_input=1.0, scale_other=1.0)
Source code in vllm_gaudi/extension/utils.py
def __init__(
    self,
    scale_input=1.0,
    scale_other=1.0,
):
    """Excerpt of ``FP8Matmul.__init__``: store the per-operand FP8 cast scales."""
    super().__init__()
    # Scale for the left operand (used in quant_input and inverted for the GEMM).
    self.scale_input = scale_input
    # Scale for the right operand (used in quant_input and inverted for the GEMM).
    self.scale_other = scale_other

forward

forward(input, other, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, other, **kwargs):
    """Excerpt of ``FP8Matmul.forward``.

    Quantizes ``input`` and ``other`` to FP8 with their scales, multiplies them
    via ``matmul_fp8`` and requests a bfloat16 result. ``kwargs`` is accepted
    but ignored.
    """
    in_scale = self.scale_input
    other_scale = self.scale_other
    lhs_fp8 = self.quant_input(input, in_scale)
    rhs_fp8 = self.quant_input(other, other_scale)
    return self.matmul_fp8(
        lhs_fp8,
        rhs_fp8,
        out_dtype=torch.bfloat16,
        scale_input_inv=1.0 / in_scale,
        scale_other_inv=1.0 / other_scale,
    )

matmul_fp8

matmul_fp8(
    x,
    other,
    out_dtype,
    scale_input_inv=None,
    scale_other_inv=None,
)
Source code in vllm_gaudi/extension/utils.py
def matmul_fp8(self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None):
    """Excerpt of ``FP8Matmul.matmul_fp8``: call the HPU FP8 GEMM on ``x`` and ``other``.

    No transposes, no bias and no accumulation are requested. The optional
    inverse scales are forwarded as ``A_scale_inv`` / ``B_scale_inv``.
    """
    gemm_args = dict(
        A=x,
        trans_A=False,
        B=other,
        trans_B=False,
        D=None,
        out_dtype=out_dtype,
        A_scale_inv=scale_input_inv,
        B_scale_inv=scale_other_inv,
        bias=None,
        accumulate=False,
    )
    return torch.ops.hpu.fp8_gemm_v2(**gemm_args)

quant_input

quant_input(x, scale)
Source code in vllm_gaudi/extension/utils.py
def quant_input(self, x, scale):
    """Excerpt of ``FP8Matmul.quant_input``: cast ``x`` to ``torch.float8_e4m3fn`` using ``scale``.

    Only the first element of the op's result is returned.
    """
    # NOTE(review): the two positional False flags are presumably the
    # stochastic-rounding and amax switches of cast_to_fp8_v2 — confirm.
    fp8_out = torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)
    return fp8_out[0]

Matmul

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class Matmul(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around ``torch.matmul``.

    Having the op as a module lets it be subclassed or swapped as a unit
    (``B2BMatmul`` subclasses it, for example). Keyword arguments are
    forwarded to ``torch.matmul`` unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y, **kwargs):
        product = torch.matmul(x, y, **kwargs)
        return product

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    """Excerpt of ``Matmul.__init__``: only runs the ``torch.nn.Module`` initializer."""
    super(Matmul, self).__init__()

forward

forward(x, y, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def forward(self, x, y, **kwargs):
    """Excerpt of ``Matmul.forward``: ``torch.matmul(x, y)`` with ``kwargs`` forwarded unchanged."""
    product = torch.matmul(x, y, **kwargs)
    return product

ModuleFP8FusedSDPA

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class ModuleFP8FusedSDPA(torch.nn.Module):
    """FP8 fused scaled-dot-product attention wrapper.

    Query/key/value are cast to FP8 (E4M3) with per-tensor scales. The
    quantized tensors are then dispatched either to the sliced FP8
    implementation (long-context chunked prefill) or to the FP8 fused SDPA
    kernel supplied at construction time.

    Args:
        fusedSDPA: the FP8 fused SDPA callable; must not be ``None``.
    """

    def __init__(self, fusedSDPA):
        super().__init__()
        assert fusedSDPA is not None, 'FP8 fusedSDPA kernel is None'
        self.fp8_fused_sdpa = fusedSDPA

        def _unit_scale():
            # Each scale gets its own float32 scalar tensor initialised to 1.0.
            return torch.tensor(1.0, dtype=torch.float32)

        # descale_amax and scale_amax are temporarily fixed at 1.0.
        self.descale_amax = _unit_scale()
        self.scale_amax = _unit_scale()
        # Quantization scales applied to q/k/v in quant_input.
        self.scale_q = _unit_scale()
        self.scale_k = _unit_scale()
        self.scale_v = _unit_scale()
        # Descales passed to the kernels for q/k/v and (sliced path) the output.
        self.d_scale_q = _unit_scale()
        self.d_scale_k = _unit_scale()
        self.d_scale_v = _unit_scale()
        self.d_scale_output = _unit_scale()
        # The sliced module reads the scale tensors above through its parent reference.
        self._sliced_module = SlicedFP8FusedSDPA(parent=self)

    def quant_input(self, x, scale):
        """Cast ``x`` to ``torch.float8_e4m3fn`` using ``scale``; return the first element of the op's result."""
        fp8_out = torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)
        return fp8_out[0]

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",
        window_size=None,
    ):
        """Quantize q/k/v to FP8 and run attention.

        The sliced module is used only when all of the following hold:
        slicing is enabled, kv_len >= the slice threshold, batch size is 1,
        q_len != kv_len, attention is causal with an explicit mask, padding is
        on the right, and there is no sliding window. Its result is cast to
        ``query.dtype``. Otherwise ``fp8_fused_sdpa`` is called and its first
        output is returned. ``recompute_mode`` is accepted but not used here.
        """
        q_fp8 = self.quant_input(query, self.scale_q).detach()
        k_fp8 = self.quant_input(key, self.scale_k).detach()
        v_fp8 = self.quant_input(value, self.scale_v).detach()

        sliced = self._sliced_module
        batch_size, q_len, kv_len = query.shape[0], query.shape[-2], key.shape[-2]
        take_sliced_path = (
            sliced.enable_slicing
            and kv_len >= sliced.slice_thld
            and batch_size == 1  # chunked prefill runs with batch size 1
            and q_len != kv_len  # plain causal prefill stays on the default kernel for performance
            and is_causal and attn_mask is not None  # sliced path needs causal attention with a mask
            and padding_side == 'right'  # padded chunks are only supported with right padding
            and window_size is None  # sliding-window attention cannot be sliced
        )
        if take_sliced_path:
            return sliced(q_fp8, k_fp8, v_fp8, attn_mask, dropout_p, is_causal, scale,
                          softmax_mode).to(query.dtype)

        if is_causal and attn_mask is not None:
            # TODO: causal + attn_bias is not yet supported; rely on the explicit mask instead.
            is_causal = False
            valid_sequence_lengths = None

        # NOTE(review): window_size only gates the sliced path above and is not
        # forwarded to fp8_fused_sdpa; q_scale_o is not passed either — confirm intended.
        results = self.fp8_fused_sdpa(
            q_fp8,
            k_fp8,
            v_fp8,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            softmax_mode=softmax_mode,
            d_scale_q=self.d_scale_q,
            d_scale_k=self.d_scale_k,
            d_scale_v=self.d_scale_v,
            q_scale_s=self.scale_amax,
            d_scale_s=self.descale_amax,
            is_amax_s=False,
            valid_seq_len=valid_sequence_lengths,
            seq_padding_type=padding_side,
        )
        return results[0]

_sliced_module instance-attribute

_sliced_module = SlicedFP8FusedSDPA(parent=self)

d_scale_k instance-attribute

d_scale_k = tensor(1.0, dtype=float32)

d_scale_output instance-attribute

d_scale_output = tensor(1.0, dtype=float32)

d_scale_q instance-attribute

d_scale_q = tensor(1.0, dtype=float32)

d_scale_v instance-attribute

d_scale_v = tensor(1.0, dtype=float32)

descale_amax instance-attribute

descale_amax = tensor(1.0, dtype=float32)

fp8_fused_sdpa instance-attribute

fp8_fused_sdpa = fusedSDPA

scale_amax instance-attribute

scale_amax = tensor(1.0, dtype=float32)

scale_k instance-attribute

scale_k = tensor(1.0, dtype=float32)

scale_q instance-attribute

scale_q = tensor(1.0, dtype=float32)

scale_v instance-attribute

scale_v = tensor(1.0, dtype=float32)

__init__

__init__(fusedSDPA)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, fusedSDPA):
    """Excerpt of ``ModuleFP8FusedSDPA.__init__``.

    Stores the FP8 fused SDPA callable, creates float32 scalar scale/descale
    tensors (all 1.0), and builds the sliced FP8 module that reads these
    scales through its parent reference.
    """
    super().__init__()
    assert fusedSDPA is not None, f'FP8 fusedSDPA kernel is None'
    self.fp8_fused_sdpa = fusedSDPA

    # set the descale_amax and scale_amax 1.0 temporarily
    self.descale_amax = torch.tensor(1.0, dtype=torch.float32)
    self.scale_amax = torch.tensor(1.0, dtype=torch.float32)
    # Scales used by quant_input to cast q/k/v to FP8.
    self.scale_q = torch.tensor(1.0, dtype=torch.float32)
    self.scale_k = torch.tensor(1.0, dtype=torch.float32)
    self.scale_v = torch.tensor(1.0, dtype=torch.float32)
    # Descales handed to the attention kernels.
    self.d_scale_q = torch.tensor(1.0, dtype=torch.float32)
    self.d_scale_k = torch.tensor(1.0, dtype=torch.float32)
    self.d_scale_v = torch.tensor(1.0, dtype=torch.float32)
    self.d_scale_output = torch.tensor(1.0, dtype=torch.float32)
    self._sliced_module = SlicedFP8FusedSDPA(parent=self)

forward

forward(
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
)
Source code in vllm_gaudi/extension/utils.py
def forward(
    self,
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
):
    """Excerpt of ``ModuleFP8FusedSDPA.forward``: quantize q/k/v to FP8 and run attention.

    The sliced module is used for batch-size-1 chunked prefill (q_len != kv_len,
    kv_len >= threshold) with causal attention, an explicit mask, right padding
    and no sliding window; otherwise ``fp8_fused_sdpa`` is called and its first
    output returned. ``recompute_mode`` is accepted but not used here.
    """
    q_fp8 = self.quant_input(query, self.scale_q).detach()
    k_fp8 = self.quant_input(key, self.scale_k).detach()
    v_fp8 = self.quant_input(value, self.scale_v).detach()

    sliced = self._sliced_module
    batch_size, q_len, kv_len = query.shape[0], query.shape[-2], key.shape[-2]
    take_sliced_path = (
        sliced.enable_slicing
        and kv_len >= sliced.slice_thld
        and batch_size == 1  # chunked prefill runs with batch size 1
        and q_len != kv_len  # plain causal prefill stays on the default kernel for performance
        and is_causal and attn_mask is not None  # sliced path needs causal attention with a mask
        and padding_side == 'right'  # padded chunks are only supported with right padding
        and window_size is None  # sliding-window attention cannot be sliced
    )
    if take_sliced_path:
        return sliced(q_fp8, k_fp8, v_fp8, attn_mask, dropout_p, is_causal, scale,
                      softmax_mode).to(query.dtype)

    if is_causal and attn_mask is not None:
        # TODO: causal + attn_bias is not yet supported; rely on the explicit mask instead.
        is_causal = False
        valid_sequence_lengths = None

    # NOTE(review): window_size only gates the sliced path above and is not
    # forwarded to fp8_fused_sdpa; q_scale_o is not passed either — confirm intended.
    results = self.fp8_fused_sdpa(
        q_fp8,
        k_fp8,
        v_fp8,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        softmax_mode=softmax_mode,
        d_scale_q=self.d_scale_q,
        d_scale_k=self.d_scale_k,
        d_scale_v=self.d_scale_v,
        q_scale_s=self.scale_amax,
        d_scale_s=self.descale_amax,
        is_amax_s=False,
        valid_seq_len=valid_sequence_lengths,
        seq_padding_type=padding_side,
    )
    return results[0]

quant_input

quant_input(x, scale)
Source code in vllm_gaudi/extension/utils.py
def quant_input(self, x, scale):
    """Excerpt of ``ModuleFP8FusedSDPA.quant_input``: cast ``x`` to ``torch.float8_e4m3fn`` using ``scale``.

    Only the first element of the op's result is returned.
    """
    # NOTE(review): the two positional False flags are presumably the
    # stochastic-rounding and amax switches of cast_to_fp8_v2 — confirm.
    fp8_out = torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)
    return fp8_out[0]

ModuleFusedSDPA

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class ModuleFusedSDPA(torch.nn.Module):
    """Dispatching wrapper around the non-FP8 HPU fused SDPA kernel.

    Long-context chunked prefill is routed to ``SlicedFusedSDPA``; everything
    else goes through ``fusedSDPA.apply``.

    Args:
        fusedSDPA: object exposing the fused SDPA ``apply``; must not be ``None``.
    """

    def __init__(self, fusedSDPA):
        super().__init__()
        assert fusedSDPA is not None, 'fusedSDPA kernel is None'
        self._hpu_kernel_fsdpa = fusedSDPA
        self._sliced_module = SlicedFusedSDPA()

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",
        window_size=None,
        sinks=None,
    ):
        """Run fused attention, using the sliced module when it applies.

        The sliced path requires all of the following: slicing enabled,
        kv_len >= threshold, batch size 1, q_len != kv_len, causal attention
        with a mask, right padding, no sliding window and no sinks. When
        ``window_size`` is ``None`` the kernel receives ``(-1, -1)``.
        """
        sliced = self._sliced_module
        take_sliced_path = (
            sliced.enable_slicing
            and key.shape[-2] >= sliced.slice_thld  # only for kv_len >= slice_thld
            and query.shape[0] == 1  # prefix-prefill runs with batch size 1
            and query.shape[-2] != key.shape[-2]  # normal prefill (q_len == kv_len) uses the default kernel
            and is_causal and attn_mask is not None  # sliced path needs causal attention with a mask
            and padding_side == 'right'  # padded chunks are only supported with right padding
            and window_size is None  # sliding-window attention cannot be sliced
            and sinks is None  # sinks are only handled by the fused kernel
        )
        if take_sliced_path:
            return sliced(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)

        if is_causal and attn_mask is not None:
            # TODO: causal + attn_bias is not yet supported; rely on the explicit mask instead.
            is_causal = False
            valid_sequence_lengths = None

        # (-1, -1) is what the kernel receives when no window is given (presumably "no window").
        kernel_window = (-1, -1) if window_size is None else window_size
        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
                                            recompute_mode, valid_sequence_lengths, padding_side, False, False,
                                            kernel_window, sinks)

_hpu_kernel_fsdpa instance-attribute

_hpu_kernel_fsdpa = fusedSDPA

_sliced_module instance-attribute

_sliced_module = SlicedFusedSDPA()

__init__

__init__(fusedSDPA)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, fusedSDPA):
    """Excerpt of ``ModuleFusedSDPA.__init__``: keep the fused kernel and build the sliced module."""
    super().__init__()
    assert fusedSDPA is not None, f'fusedSDPA kernel is None'
    # Object whose ``apply`` runs the fused SDPA kernel (see forward).
    self._hpu_kernel_fsdpa = fusedSDPA
    self._sliced_module = SlicedFusedSDPA()

forward

forward(
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
    sinks=None,
)
Source code in vllm_gaudi/extension/utils.py
def forward(
    self,
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
    sinks=None,
):
    """Excerpt of ``ModuleFusedSDPA.forward``: fused attention with optional slicing.

    The sliced module handles batch-size-1 prefix-prefill (q_len != kv_len,
    kv_len >= threshold) with causal attention, a mask, right padding, no
    window and no sinks. Everything else goes to ``_hpu_kernel_fsdpa.apply``,
    which receives ``(-1, -1)`` when ``window_size`` is ``None``.
    """
    sliced = self._sliced_module
    take_sliced_path = (
        sliced.enable_slicing
        and key.shape[-2] >= sliced.slice_thld  # only for kv_len >= slice_thld
        and query.shape[0] == 1  # prefix-prefill runs with batch size 1
        and query.shape[-2] != key.shape[-2]  # normal prefill (q_len == kv_len) uses the default kernel
        and is_causal and attn_mask is not None  # sliced path needs causal attention with a mask
        and padding_side == 'right'  # padded chunks are only supported with right padding
        and window_size is None  # sliding-window attention cannot be sliced
        and sinks is None  # sinks are only handled by the fused kernel
    )
    if take_sliced_path:
        return sliced(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)

    if is_causal and attn_mask is not None:
        # TODO: causal + attn_bias is not yet supported; rely on the explicit mask instead.
        is_causal = False
        valid_sequence_lengths = None

    kernel_window = (-1, -1) if window_size is None else window_size
    return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
                                        recompute_mode, valid_sequence_lengths, padding_side, False, False,
                                        kernel_window, sinks)

SlicedFP8FusedSDPA

Bases: SlicedFusedSDPABase

Standalone module for FP8 sliced FusedSDPA.

Like `SlicedFusedSDPA`, extracting the sliced path enables wrapping with `torch.compile` or `ht.hpu.wrap_in_hpu_graph`. Expects pre-quantized FP8 inputs; dequantizes chunk outputs to BF16/FP32 before the online-softmax rescaling merge.

Source code in vllm_gaudi/extension/utils.py
class SlicedFP8FusedSDPA(SlicedFusedSDPABase):
    """Standalone module for FP8 sliced FusedSDPA.

    As with ``SlicedFusedSDPA``, isolating the sliced path allows wrapping it
    with ``torch.compile`` or ``ht.hpu.wrap_in_hpu_graph``. Inputs must already
    be FP8-quantized. Each chunk's output is dequantized to bfloat16 and then
    converted to float32 before the online-softmax rescaling merge.
    """

    def __init__(self, parent):
        super().__init__()
        # object.__setattr__ bypasses nn.Module.__setattr__, so the parent is
        # NOT registered as a submodule (that would make the module graph
        # circular). The parent's scale tensors are read through this reference.
        object.__setattr__(self, '_parent', parent)

    def _dequant_output(self, output):
        """Cast an FP8 kernel output to bfloat16 using the parent's ``d_scale_output``."""
        descale = self._parent.d_scale_output
        return torch.ops.hpu.cast_from_fp8(output, descale, torch.bfloat16)

    def _fp8_fsdpa_fwd(self, q, k, v, attn_mask, dropout_p, scale, is_causal, softmax_mode):
        """Run the FP8 recompute-forward SDPA kernel on one chunk, using the parent's scales.

        Padding is fixed to right, with no window, no sinks and no valid lengths.
        """
        p = self._parent
        return torch.ops.hpu.fp8_sdpa_recomp_fwd(
            q,
            k,
            v,
            attn_mask,
            dropout_p,
            scale,
            is_causal,
            True,  # requires_backward
            softmax_mode,
            p.d_scale_q,
            p.d_scale_k,
            p.d_scale_v,
            p.scale_amax,
            p.d_scale_output,
            p.descale_amax,
            False,  # is_amax_s
            False,  # is_amax_o
            None,  # valid_seq_len
            "right",  # padding_side
            (-1, -1),  # window_size
            None,  # sinks
        )

    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
        """Chunked causal FP8 attention; requires ``is_causal`` and a non-``None`` mask."""
        assert is_causal and attn_mask is not None

        from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import (
            gqa_input_reshape_fwd,
            gqa_output_reshape,
            is_gqa,
        )
        gqa = is_gqa(query, key)
        q, k, v, attn_mask = (gqa_input_reshape_fwd(query, key, value, attn_mask)
                              if gqa else (query, key, value, attn_mask))
        # Any mode other than "fp32" runs in the kernel's "fast" mode.
        if softmax_mode != "fp32":
            softmax_mode = "fast"
        if scale is None:
            scale = 1.0 / (query.shape[-1]**0.5)

        def _run_chunk(q_c, k_c, v_c, mask_c, dp, sc, is_c, sm):
            raw = self._fp8_fsdpa_fwd(q_c, k_c, v_c, mask_c, dp, sc, is_c, sm)
            head = raw[:3]
            if gqa:
                head = tuple(gqa_output_reshape(t) for t in head)
            out, m, linv = head
            # NOTE(review): in "fast" mode linv is multiplied by 128 — presumably
            # the kernel returns it pre-scaled in that mode; confirm.
            linv_factor = 128.0 if sm == "fast" else 1.0
            m = m.to(torch.float32)
            linv = linv.to(torch.float32) * linv_factor
            out = self._dequant_output(out).to(torch.float32)
            return out, m, linv

        return self._chunked_attention(q, k, v, attn_mask, dropout_p, scale, softmax_mode, _run_chunk)

__init__

__init__(parent)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, parent):
    """Excerpt of ``SlicedFP8FusedSDPA.__init__``: keep an unregistered reference to ``parent``."""
    super().__init__()
    # Store parent reference without registering as a submodule
    # to avoid circular module graph while sharing scale tensors.
    object.__setattr__(self, '_parent', parent)

_dequant_output

_dequant_output(output)
Source code in vllm_gaudi/extension/utils.py
def _dequant_output(self, output):
    """Excerpt of ``SlicedFP8FusedSDPA._dequant_output``: cast FP8 ``output`` to bfloat16 using the parent's ``d_scale_output``."""
    descale = self._parent.d_scale_output
    return torch.ops.hpu.cast_from_fp8(output, descale, torch.bfloat16)

_fp8_fsdpa_fwd

_fp8_fsdpa_fwd(
    q,
    k,
    v,
    attn_mask,
    dropout_p,
    scale,
    is_causal,
    softmax_mode,
)
Source code in vllm_gaudi/extension/utils.py
def _fp8_fsdpa_fwd(self, q, k, v, attn_mask, dropout_p, scale, is_causal, softmax_mode):
    """Excerpt of ``SlicedFP8FusedSDPA._fp8_fsdpa_fwd``: run FP8 recompute-forward SDPA on one chunk.

    The q/k/v/output descales and the amax scales are read from the parent
    module. Padding is fixed to right, with no window, no sinks and no valid
    lengths.
    """
    p = self._parent
    return torch.ops.hpu.fp8_sdpa_recomp_fwd(
        q,
        k,
        v,
        attn_mask,
        dropout_p,
        scale,
        is_causal,
        True,  # requires_backward
        softmax_mode,
        p.d_scale_q,
        p.d_scale_k,
        p.d_scale_v,
        p.scale_amax,
        p.d_scale_output,
        p.descale_amax,
        False,  # is_amax_s
        False,  # is_amax_o
        None,  # valid_seq_len
        "right",  # padding_side
        (-1, -1),  # window_size
        None,  # sinks
    )

forward

forward(
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
)
Source code in vllm_gaudi/extension/utils.py
def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
    """Excerpt of ``SlicedFP8FusedSDPA.forward``: chunked causal FP8 attention.

    Requires ``is_causal`` and a non-``None`` mask. Any softmax mode other
    than ``"fp32"`` is mapped to ``"fast"``.
    """
    assert is_causal and attn_mask is not None

    from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import (
        gqa_input_reshape_fwd,
        gqa_output_reshape,
        is_gqa,
    )
    gqa = is_gqa(query, key)
    q, k, v, attn_mask = (gqa_input_reshape_fwd(query, key, value, attn_mask)
                          if gqa else (query, key, value, attn_mask))
    if softmax_mode != "fp32":
        softmax_mode = "fast"
    if scale is None:
        scale = 1.0 / (query.shape[-1]**0.5)

    def _run_chunk(q_c, k_c, v_c, mask_c, dp, sc, is_c, sm):
        raw = self._fp8_fsdpa_fwd(q_c, k_c, v_c, mask_c, dp, sc, is_c, sm)
        head = raw[:3]
        if gqa:
            head = tuple(gqa_output_reshape(t) for t in head)
        out, m, linv = head
        # NOTE(review): the 128x factor in "fast" mode presumably undoes a
        # kernel-side scaling of linv — confirm.
        linv_factor = 128.0 if sm == "fast" else 1.0
        m = m.to(torch.float32)
        linv = linv.to(torch.float32) * linv_factor
        out = self._dequant_output(out).to(torch.float32)
        return out, m, linv

    return self._chunked_attention(q, k, v, attn_mask, dropout_p, scale, softmax_mode, _run_chunk)

SlicedFusedSDPA

Bases: SlicedFusedSDPABase

Standalone module for BF16 sliced FusedSDPA.

Extracting the sliced attention path into its own `nn.Module` allows it to be wrapped with `torch.compile`, `ht.hpu.wrap_in_hpu_graph`, or any other module-level wrapper independently of the dispatch logic in `ModuleFusedSDPA`.

Source code in vllm_gaudi/extension/utils.py
class SlicedFusedSDPA(SlicedFusedSDPABase):
    """Standalone module for BF16 sliced FusedSDPA.

    Keeping the sliced attention path in its own ``nn.Module`` lets it be
    wrapped with ``torch.compile``, ``ht.hpu.wrap_in_hpu_graph`` or any other
    module-level wrapper, independently of the dispatch logic in
    ``ModuleFusedSDPA``.
    """

    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
        """Chunked causal attention; requires ``is_causal`` and a non-``None`` mask.

        Chunk results are merged in float32, and the final output is cast back
        to the (possibly GQA-reshaped) query dtype.
        """
        assert is_causal and attn_mask is not None

        from habana_frameworks.torch.hpex.kernels.FusedSDPA import (
            gqa_input_reshape_fwd,
            gqa_output_reshape,
            is_gqa,
        )
        gqa = is_gqa(query, key)
        q, k, v, attn_mask = (gqa_input_reshape_fwd(query, key, value, attn_mask)
                              if gqa else (query, key, value, attn_mask))
        if scale is None:
            scale = 1.0 / (query.shape[-1]**0.5)

        def _run_chunk(q_c, k_c, v_c, mask_c, dp, sc, is_c, sm):
            raw = torch.ops.hpu.sdpa_recomp_fwd(q_c, k_c, v_c, mask_c, dp, sc, is_c, True, sm, None, 'right')
            converted = []
            for t in raw[:3]:
                if gqa:
                    t = gqa_output_reshape(t)
                converted.append(t.to(torch.float32))
            out, m, linv = converted
            return out, m, linv

        merged = self._chunked_attention(q, k, v, attn_mask, dropout_p, scale, softmax_mode, _run_chunk)
        return merged.to(q.dtype)

forward

forward(
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
)
Source code in vllm_gaudi/extension/utils.py
def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
    """Excerpt of ``SlicedFusedSDPA.forward``: chunked causal BF16 attention.

    Requires ``is_causal`` and a non-``None`` mask. Chunk results are
    converted to float32 for merging, and the final output is cast back to
    the query dtype.
    """
    assert is_causal and attn_mask is not None

    from habana_frameworks.torch.hpex.kernels.FusedSDPA import (
        gqa_input_reshape_fwd,
        gqa_output_reshape,
        is_gqa,
    )
    gqa = is_gqa(query, key)
    q, k, v, attn_mask = (gqa_input_reshape_fwd(query, key, value, attn_mask)
                          if gqa else (query, key, value, attn_mask))
    if scale is None:
        scale = 1.0 / (query.shape[-1]**0.5)

    def _run_chunk(q_c, k_c, v_c, mask_c, dp, sc, is_c, sm):
        raw = torch.ops.hpu.sdpa_recomp_fwd(q_c, k_c, v_c, mask_c, dp, sc, is_c, True, sm, None, 'right')
        converted = []
        for t in raw[:3]:
            if gqa:
                t = gqa_output_reshape(t)
            converted.append(t.to(torch.float32))
        out, m, linv = converted
        return out, m, linv

    merged = self._chunked_attention(q, k, v, attn_mask, dropout_p, scale, softmax_mode, _run_chunk)
    return merged.to(q.dtype)

SlicedFusedSDPABase

Bases: Module

Base class for sliced FusedSDPA modules.

Encapsulates the common slicing initialization (chunk size, padded chunk counts, graph-break setup) shared by `SlicedFusedSDPA` and `SlicedFP8FusedSDPA`.

Source code in vllm_gaudi/extension/utils.py
class SlicedFusedSDPABase(torch.nn.Module):
    """Base class for sliced FusedSDPA modules.

    Encapsulates the common slicing initialization (chunk size, padded chunk
    counts, graph-break setup) shared by :class:`SlicedFusedSDPA` and
    :class:`SlicedFP8FusedSDPA`.
    """

    def __init__(self):
        super().__init__()
        # _setup_slicing() returns False without assigning any slicing
        # attributes when slicing is disabled or incompatible with the config.
        slicing_enabled = self._setup_slicing()
        self.enable_slicing = slicing_enabled

    def _setup_slicing(self) -> bool:
        """Decide whether FusedSDPA slicing is enabled and, if so, configure it.

        Reads the global config, the bucketing manager and several environment
        variables. Only on the enabled path does it set ``slice_thld``,
        ``chunk_size``, ``num_padded_query_chunks``, ``num_padded_ctx_chunks``
        and ``_with_graph_breaks`` (plus ``_break_graph`` when graph breaks are
        on). None of these attributes are set when it returns False.

        Environment variables:
            VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: minimum kv length that is sliced
                (default ``min(max_num_batched_tokens, 8192)``).
            VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size, rounded up to a multiple
                of 1024 (default: half the threshold, rounded up to a multiple of 1024).
            VLLM_PROMPT_QUERY_BUCKET_PAD_MAX: max query padding in tokens
                (default ``ceil(max_num_batched_tokens / 4)``).
            VLLM_PROMPT_CTX_BUCKET_PAD_MAX: max context padding in blocks
                (default ``ceil(max_num_batched_tokens / block_size)``).
            VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: insert ``mark_step`` graph
                breaks (default on in lazy mode; always forced off outside lazy mode).

        Returns:
            True if slicing is enabled. False if the feature flag is off
            (silently) or the config is incompatible (with a warning).

        Raises:
            AssertionError: if the bucketing manager/strategy or the threshold,
                chunk-size or query-padding settings are invalid.
        """
        enable_slicing = get_config().enable_fsdpa_slicing
        if not enable_slicing:
            return False

        # Incompatible configurations: warn once and disable slicing.
        if get_config().bucketing_strategy != 'pad':
            logger().warning_once(
                'FusedSDPA slicing is only compatible with padding-based bucketing strategy, slicing in FusedSDPA will be disabled.'
            )
            return False

        if get_config().merged_prefill:
            logger().warning_once(
                'FusedSDPA slicing is not compatible with merged prefill, slicing in FusedSDPA will be disabled.')
            return False

        if not get_config().use_bucketing:
            logger().warning_once(
                'FusedSDPA slicing requires bucketing to be enabled, slicing in FusedSDPA will be disabled.')
            return False

        from vllm_gaudi.extension.bucketing.common import get_bucketing_manager
        bucketing_manager = get_bucketing_manager()
        assert bucketing_manager is not None and bucketing_manager.initialized, 'Bucketing manager should be instantiated and initialized to enable FusedSDPA slicing.'

        from vllm_gaudi.extension.bucketing.padding_aware import PaddingAwareBucketingStrategy
        strategy = bucketing_manager.get_bucketing_strategy()
        assert isinstance(
            strategy,
            PaddingAwareBucketingStrategy), 'Bucketing strategy should be Padding-Aware to enable FusedSDPA slicing.'

        # Slice threshold: kv lengths at or above this value take the sliced path.
        max_num_batched_tokens = bucketing_manager.max_num_batched_tokens
        block_size = bucketing_manager.block_size
        slice_thld_default = min(max_num_batched_tokens, 8192)
        slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default)))
        assert slice_thld > block_size, 'Invalid FusedSDPA slice sequence length threshold, the threshold should be greater than the block size.'
        assert slice_thld >= 1024, 'The FusedSDPA slice sequence length threshold should be greater than or equal to 1024 to ensure the chunk sizes are valid for the attention kernel.'
        if slice_thld < slice_thld_default:
            logger().warning_once(
                f'The FusedSDPA slice sequence length threshold {slice_thld} is less than the default {slice_thld_default} which is not recommended.'
            )

        # Default: half of the threshold (floor-divided), rounded up to a multiple of 1024.
        chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024
        chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default)))
        # A user-supplied chunk size is rounded up to a multiple of 1024.
        if chunk_size % 1024 != 0:
            chunk_size = math.ceil(chunk_size / 1024) * 1024
            logger().warning_once('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.')
        assert chunk_size > block_size and chunk_size <= slice_thld, 'Invalid FusedSDPA slice chunk size, the chunk size should be between the block size and the slice sequence length threshold.'

        self.slice_thld = slice_thld
        self.chunk_size = chunk_size

        # should align with the default value in PaddingAwareBucketingStrategy
        max_query_pad_default = math.ceil(max_num_batched_tokens / 4)
        max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default)))
        assert max_query_pad >= block_size, 'Invalid max query padding, the max query padding should be greater than or equal to the block size.'
        # Number of chunks the maximum query padding can span.
        self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size)

        # should align with the default value in PaddingAwareBucketingStrategy
        # max_ctx_pad is counted in blocks, hence the multiplication by block_size below.
        max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size)
        max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default)))
        self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size)

        # Graph breaks default to on in lazy mode and are forced off otherwise.
        import habana_frameworks.torch as ht
        is_lazy = ht.utils.internal.is_lazy()
        self._with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS",
                                            str(is_lazy)).strip().lower() in ['true', 't', '1', 'yes', 'y', 'on']
        if self._with_graph_breaks and not is_lazy:
            logger().warning_once('FusedSDPA slicing graph breaks are only supported in lazy mode. '
                                  'Disabling graph breaks for eager/compile mode to avoid Synapse compiler failures.')
            self._with_graph_breaks = False
        if self._with_graph_breaks:
            self._break_graph = ht.core.mark_step

        msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, "
               f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, "
               f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self._with_graph_breaks}.")
        logger().debug_once(msg)

        return True

    def maybe_break_graph(self):
        if self._with_graph_breaks:
            self._break_graph()

    @staticmethod
    def _merge_chunk(last_out, last_m, last_linv, chunk_out, chunk_m, chunk_linv):
        """Online softmax rescaling merge of two attention chunks."""
        if last_out is None or last_m is None or last_linv is None:
            return chunk_out, chunk_m, chunk_linv
        new_m = torch.maximum(last_m, chunk_m)
        last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
        chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
        new_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
        new_out = (last_linv_rescaled * new_linv) * last_out + (chunk_linv_rescaled * new_linv) * chunk_out
        return new_out, new_m, new_linv

    def _chunked_attention(self, q, k, v, attn_mask, dropout_p, scale, softmax_mode, chunk_kernel_fn):
        """Run chunked attention with online softmax rescaling.

        Query chunks of ``self.chunk_size`` rows are cut from the end of the
        query, so only the leftmost chunk can be shorter. The key/value sequence
        is indexed as ``[prefix | new tokens]`` with
        ``prefix_len = kv_len - q_len``. Each query chunk is attended against:

        1. the causal key/value chunks of the new tokens, starting at the
           diagonal chunk and walking backwards to ``prefix_len``, then
        2. the prefix (context) key/value chunks, walking backwards from
           ``prefix_len`` to position 0.

        Partial results are folded with ``self._merge_chunk``. A mask slice is
        passed for the diagonal chunk, for causal chunks with
        ``kv_chunk_idx < self.num_padded_query_chunks`` and for context chunks
        with ``kv_chunk_idx < self.num_padded_ctx_chunks``. All other chunks get
        ``None``. ``self.maybe_break_graph()`` brackets every kernel call.

        Args:
            q, k, v: Query, key, value tensors (after GQA reshape if needed).
            attn_mask: Attention mask tensor.
            dropout_p: Dropout probability.
            scale: Attention scale factor.
            softmax_mode: Softmax mode string.
            chunk_kernel_fn: Callable
                ``(q, k, v, mask, dropout_p, scale, is_causal, softmax_mode)``
                returning ``(out, m, linv)`` all as float32.

        Returns:
            Concatenated output tensor in float32, with the per-query-chunk
            results joined along dim ``-2`` in sequence order.
        """
        chunk_size = self.chunk_size
        q_len = q.shape[-2]
        kv_len = k.shape[-2]
        prefix_len = kv_len - q_len

        num_q_chunks = math.ceil(q_len / chunk_size)
        num_prefix_chunks = math.ceil(prefix_len / chunk_size)

        chunk_outputs = []
        for q_chunk_idx in range(num_q_chunks):
            q_end = q_len - q_chunk_idx * chunk_size
            q_start = max(q_end - chunk_size, 0)
            q_chunk = q[..., q_start:q_end, :].contiguous()

            last_out = last_m = last_linv = None

            # the causal part: new-token keys/values up to this chunk's diagonal
            for kv_chunk_idx in range(num_q_chunks - q_chunk_idx):
                kv_end = prefix_len + q_end - kv_chunk_idx * chunk_size
                kv_start = max(kv_end - chunk_size, prefix_len)
                k_chunk = k[..., kv_start:kv_end, :].contiguous()
                v_chunk = v[..., kv_start:kv_end, :].contiguous()

                # Always pass explicit mask for the diagonal chunk (kv_chunk_idx==0)
                # to ensure numerical consistency. The kernel's is_causal=True path
                # uses a different internal algorithm that can diverge from the
                # explicit mask path even when both encode the same triangular pattern.
                # For non-diagonal chunks within the padded region, also pass mask.
                use_mask = kv_chunk_idx == 0 or kv_chunk_idx < self.num_padded_query_chunks
                mask_chunk = attn_mask[..., q_start:q_end, kv_start:kv_end].contiguous() if use_mask else None

                self.maybe_break_graph()

                chunk_out, chunk_m, chunk_linv = chunk_kernel_fn(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p,
                                                                 scale, False, softmax_mode)

                last_out, last_m, last_linv = self._merge_chunk(last_out, last_m, last_linv, chunk_out, chunk_m,
                                                                chunk_linv)

                self.maybe_break_graph()

            # the context part: prefix keys/values, walking back from prefix_len
            for kv_chunk_idx in range(num_prefix_chunks):
                kv_end = prefix_len - kv_chunk_idx * chunk_size
                kv_start = max(kv_end - chunk_size, 0)
                k_chunk = k[..., kv_start:kv_end, :].contiguous()
                v_chunk = v[..., kv_start:kv_end, :].contiguous()
                # use mask only for the chunks that may have padding
                use_mask = kv_chunk_idx < self.num_padded_ctx_chunks
                mask_chunk = attn_mask[..., q_start:q_end, kv_start:kv_end].contiguous() if use_mask else None

                self.maybe_break_graph()

                chunk_out, chunk_m, chunk_linv = chunk_kernel_fn(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p,
                                                                 scale, False, softmax_mode)

                # The causal loop runs at least once per query chunk, so a partial result exists.
                assert not (last_out is None or last_m is None or last_linv is None)
                last_out, last_m, last_linv = self._merge_chunk(last_out, last_m, last_linv, chunk_out, chunk_m,
                                                                chunk_linv)

                self.maybe_break_graph()
            chunk_outputs.append(last_out)

        # Query chunks were produced right-to-left; restore sequence order.
        return torch.cat(chunk_outputs[::-1], dim=-2)

enable_slicing instance-attribute

enable_slicing = _setup_slicing()

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    """Initialise the module and record whether FusedSDPA slicing is active."""
    super().__init__()
    # _setup_slicing validates the configuration and stores slicing parameters on self.
    slicing_enabled = self._setup_slicing()
    self.enable_slicing = slicing_enabled

_chunked_attention

_chunked_attention(
    q,
    k,
    v,
    attn_mask,
    dropout_p,
    scale,
    softmax_mode,
    chunk_kernel_fn,
)

Run chunked attention with online softmax rescaling.

Parameters:

Name Type Description Default
q, k, v

Query, key, value tensors (after GQA reshape if needed).

required
attn_mask

Attention mask tensor.

required
dropout_p

Dropout probability.

required
scale

Attention scale factor.

required
softmax_mode

Softmax mode string.

required
chunk_kernel_fn

Callable (q, k, v, mask, dropout_p, scale, is_causal, softmax_mode) returning (out, m, linv) all as float32.

required

Returns:

Type Description

Concatenated output tensor in float32.

Source code in vllm_gaudi/extension/utils.py
def _chunked_attention(self, q, k, v, attn_mask, dropout_p, scale, softmax_mode, chunk_kernel_fn):
    """Run chunked attention with online softmax rescaling.

    Query chunks of ``self.chunk_size`` rows are cut from the end of the
    query, so only the leftmost chunk can be shorter. The key/value sequence
    is indexed as ``[prefix | new tokens]`` with
    ``prefix_len = kv_len - q_len``. Each query chunk is attended against:

    1. the causal key/value chunks of the new tokens, starting at the
       diagonal chunk and walking backwards to ``prefix_len``, then
    2. the prefix (context) key/value chunks, walking backwards from
       ``prefix_len`` to position 0.

    Partial results are folded with ``self._merge_chunk``. A mask slice is
    passed for the diagonal chunk, for causal chunks with
    ``kv_chunk_idx < self.num_padded_query_chunks`` and for context chunks
    with ``kv_chunk_idx < self.num_padded_ctx_chunks``. All other chunks get
    ``None``. ``self.maybe_break_graph()`` brackets every kernel call.

    Args:
        q, k, v: Query, key, value tensors (after GQA reshape if needed).
        attn_mask: Attention mask tensor.
        dropout_p: Dropout probability.
        scale: Attention scale factor.
        softmax_mode: Softmax mode string.
        chunk_kernel_fn: Callable
            ``(q, k, v, mask, dropout_p, scale, is_causal, softmax_mode)``
            returning ``(out, m, linv)`` all as float32.

    Returns:
        Concatenated output tensor in float32, with the per-query-chunk
        results joined along dim ``-2`` in sequence order.
    """
    chunk_size = self.chunk_size
    q_len = q.shape[-2]
    kv_len = k.shape[-2]
    prefix_len = kv_len - q_len

    num_q_chunks = math.ceil(q_len / chunk_size)
    num_prefix_chunks = math.ceil(prefix_len / chunk_size)

    chunk_outputs = []
    for q_chunk_idx in range(num_q_chunks):
        q_end = q_len - q_chunk_idx * chunk_size
        q_start = max(q_end - chunk_size, 0)
        q_chunk = q[..., q_start:q_end, :].contiguous()

        last_out = last_m = last_linv = None

        # the causal part: new-token keys/values up to this chunk's diagonal
        for kv_chunk_idx in range(num_q_chunks - q_chunk_idx):
            kv_end = prefix_len + q_end - kv_chunk_idx * chunk_size
            kv_start = max(kv_end - chunk_size, prefix_len)
            k_chunk = k[..., kv_start:kv_end, :].contiguous()
            v_chunk = v[..., kv_start:kv_end, :].contiguous()

            # Always pass explicit mask for the diagonal chunk (kv_chunk_idx==0)
            # to ensure numerical consistency. The kernel's is_causal=True path
            # uses a different internal algorithm that can diverge from the
            # explicit mask path even when both encode the same triangular pattern.
            # For non-diagonal chunks within the padded region, also pass mask.
            use_mask = kv_chunk_idx == 0 or kv_chunk_idx < self.num_padded_query_chunks
            mask_chunk = attn_mask[..., q_start:q_end, kv_start:kv_end].contiguous() if use_mask else None

            self.maybe_break_graph()

            chunk_out, chunk_m, chunk_linv = chunk_kernel_fn(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p,
                                                             scale, False, softmax_mode)

            last_out, last_m, last_linv = self._merge_chunk(last_out, last_m, last_linv, chunk_out, chunk_m,
                                                            chunk_linv)

            self.maybe_break_graph()

        # the context part: prefix keys/values, walking back from prefix_len
        for kv_chunk_idx in range(num_prefix_chunks):
            kv_end = prefix_len - kv_chunk_idx * chunk_size
            kv_start = max(kv_end - chunk_size, 0)
            k_chunk = k[..., kv_start:kv_end, :].contiguous()
            v_chunk = v[..., kv_start:kv_end, :].contiguous()
            # use mask only for the chunks that may have padding
            use_mask = kv_chunk_idx < self.num_padded_ctx_chunks
            mask_chunk = attn_mask[..., q_start:q_end, kv_start:kv_end].contiguous() if use_mask else None

            self.maybe_break_graph()

            chunk_out, chunk_m, chunk_linv = chunk_kernel_fn(q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p,
                                                             scale, False, softmax_mode)

            # The causal loop runs at least once per query chunk, so a partial result exists.
            assert not (last_out is None or last_m is None or last_linv is None)
            last_out, last_m, last_linv = self._merge_chunk(last_out, last_m, last_linv, chunk_out, chunk_m,
                                                            chunk_linv)

            self.maybe_break_graph()
        chunk_outputs.append(last_out)

    # Query chunks were produced right-to-left; restore sequence order.
    return torch.cat(chunk_outputs[::-1], dim=-2)

_merge_chunk staticmethod

_merge_chunk(
    last_out,
    last_m,
    last_linv,
    chunk_out,
    chunk_m,
    chunk_linv,
)

Online softmax rescaling merge of two attention chunks.

Source code in vllm_gaudi/extension/utils.py
@staticmethod
def _merge_chunk(last_out, last_m, last_linv, chunk_out, chunk_m, chunk_linv):
    """Fold one attention chunk into the running online-softmax state.

    Each partial result is ``(out, m, linv)``. ``1 / linv`` is rescaled by
    ``exp(m - merged_max)`` to a common maximum, which is how the formula
    treats it as that chunk's softmax normaliser. The merged output is the
    normaliser-weighted mix of both outputs. When nothing has been
    accumulated yet (any of ``last_*`` is ``None``), the chunk's own triple
    is returned unchanged.
    """
    if any(part is None for part in (last_out, last_m, last_linv)):
        return chunk_out, chunk_m, chunk_linv
    merged_max = torch.maximum(last_m, chunk_m)
    # Normalisers of both sides, re-referenced to the shared maximum.
    last_norm = (1.0 / last_linv) * torch.exp(last_m - merged_max)
    chunk_norm = (1.0 / chunk_linv) * torch.exp(chunk_m - merged_max)
    merged_linv = 1.0 / (last_norm + chunk_norm)
    last_weight = last_norm * merged_linv
    chunk_weight = chunk_norm * merged_linv
    merged_out = last_weight * last_out + chunk_weight * chunk_out
    return merged_out, merged_max, merged_linv

_setup_slicing

_setup_slicing() -> bool
Source code in vllm_gaudi/extension/utils.py
def _setup_slicing(self) -> bool:
    """Decide whether FusedSDPA slicing is enabled and, if so, configure it.

    Returns ``False`` when slicing is disabled in the config, or when it is
    incompatible with the active configuration. The incompatible cases log a
    one-time warning: a bucketing strategy other than ``'pad'``, merged
    prefill, or bucketing disabled. These early exits leave all slicing
    attributes below unset.

    Otherwise it asserts that the bucketing manager is initialised and uses a
    ``PaddingAwareBucketingStrategy``, and validates the environment
    overrides. It then stores ``slice_thld``, ``chunk_size``,
    ``num_padded_query_chunks``, ``num_padded_ctx_chunks`` and
    ``_with_graph_breaks`` on ``self``, plus ``_break_graph`` when graph
    breaks are on, and returns ``True``.

    Environment variables read:
        VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD: sequence-length threshold
            (default ``min(max_num_batched_tokens, 8192)``).
        VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE: chunk size, rounded up to a multiple
            of 1024 (default derived from the threshold, see below).
        VLLM_PROMPT_QUERY_BUCKET_PAD_MAX: max query padding, divided directly
            by the chunk size (default ``ceil(max_num_batched_tokens / 4)``).
        VLLM_PROMPT_CTX_BUCKET_PAD_MAX: max context padding, multiplied by
            ``block_size`` before use, i.e. counted in blocks
            (default ``ceil(max_num_batched_tokens / block_size)``).
        VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS: boolean-like string; defaults
            to whether the runtime is in lazy mode, and is forced off outside it.

    Raises:
        AssertionError: if the bucketing manager/strategy or any threshold,
            chunk size or padding value fails validation. Note that
            ``slice_thld``/``chunk_size`` are already assigned when the
            query-padding assertion runs.
    """
    enable_slicing = get_config().enable_fsdpa_slicing
    if not enable_slicing:
        return False

    # Slicing relies on the padding-aware bucket layout; other strategies are unsupported.
    if get_config().bucketing_strategy != 'pad':
        logger().warning_once(
            'FusedSDPA slicing is only compatible with padding-based bucketing strategy, slicing in FusedSDPA will be disabled.'
        )
        return False

    if get_config().merged_prefill:
        logger().warning_once(
            'FusedSDPA slicing is not compatible with merged prefill, slicing in FusedSDPA will be disabled.')
        return False

    if not get_config().use_bucketing:
        logger().warning_once(
            'FusedSDPA slicing requires bucketing to be enabled, slicing in FusedSDPA will be disabled.')
        return False

    # Imported lazily (function scope) rather than at module import time.
    from vllm_gaudi.extension.bucketing.common import get_bucketing_manager
    bucketing_manager = get_bucketing_manager()
    assert bucketing_manager is not None and bucketing_manager.initialized, 'Bucketing manager should be instantiated and initialized to enable FusedSDPA slicing.'

    from vllm_gaudi.extension.bucketing.padding_aware import PaddingAwareBucketingStrategy
    strategy = bucketing_manager.get_bucketing_strategy()
    assert isinstance(
        strategy,
        PaddingAwareBucketingStrategy), 'Bucketing strategy should be Padding-Aware to enable FusedSDPA slicing.'

    max_num_batched_tokens = bucketing_manager.max_num_batched_tokens
    block_size = bucketing_manager.block_size
    slice_thld_default = min(max_num_batched_tokens, 8192)
    slice_thld = int(os.getenv("VLLM_HPU_FSDPA_SLICE_SEQ_LEN_THLD", str(slice_thld_default)))
    assert slice_thld > block_size, 'Invalid FusedSDPA slice sequence length threshold, the threshold should be greater than the block size.'
    assert slice_thld >= 1024, 'The FusedSDPA slice sequence length threshold should be greater than or equal to 1024 to ensure the chunk sizes are valid for the attention kernel.'
    if slice_thld < slice_thld_default:
        logger().warning_once(
            f'The FusedSDPA slice sequence length threshold {slice_thld} is less than the default {slice_thld_default} which is not recommended.'
        )

    # defaults to half of the threshold (floor-divided by 2) rounded up to a multiple of 1024
    chunk_size_default = math.ceil(slice_thld // 2 / 1024) * 1024
    chunk_size = int(os.getenv("VLLM_HPU_FSDPA_SLICE_CHUNK_SIZE", str(chunk_size_default)))
    if chunk_size % 1024 != 0:
        chunk_size = math.ceil(chunk_size / 1024) * 1024
        logger().warning_once('Rounded up the chunk size for FusedSDPA slicing to the next multiple of 1024.')
    # Checked after rounding, so a user value rounded past the threshold is rejected here.
    assert chunk_size > block_size and chunk_size <= slice_thld, 'Invalid FusedSDPA slice chunk size, the chunk size should be between the block size and the slice sequence length threshold.'

    self.slice_thld = slice_thld
    self.chunk_size = chunk_size

    # should align with the default value in PaddingAwareBucketingStrategy
    max_query_pad_default = math.ceil(max_num_batched_tokens / 4)
    max_query_pad = int(os.getenv("VLLM_PROMPT_QUERY_BUCKET_PAD_MAX", str(max_query_pad_default)))
    assert max_query_pad >= block_size, 'Invalid max query padding, the max query padding should be greater than or equal to the block size.'
    # Number of chunk_size-wide chunks the query padding can span (rounded up).
    self.num_padded_query_chunks = math.ceil(max_query_pad / self.chunk_size)

    # should align with the default value in PaddingAwareBucketingStrategy
    max_ctx_pad_default = math.ceil(max_num_batched_tokens / block_size)
    max_ctx_pad = int(os.getenv("VLLM_PROMPT_CTX_BUCKET_PAD_MAX", str(max_ctx_pad_default)))
    # Context padding is expressed in blocks; convert to tokens before counting chunks.
    self.num_padded_ctx_chunks = math.ceil(max_ctx_pad * block_size / self.chunk_size)

    import habana_frameworks.torch as ht
    is_lazy = ht.utils.internal.is_lazy()
    # Boolean-like env parsing; the default follows the current execution mode.
    self._with_graph_breaks = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_GRAPH_BREAKS",
                                        str(is_lazy)).strip().lower() in ['true', 't', '1', 'yes', 'y', 'on']
    if self._with_graph_breaks and not is_lazy:
        logger().warning_once('FusedSDPA slicing graph breaks are only supported in lazy mode. '
                              'Disabling graph breaks for eager/compile mode to avoid Synapse compiler failures.')
        self._with_graph_breaks = False
    if self._with_graph_breaks:
        # Only bound when graph breaks are on; maybe_break_graph checks the flag first.
        self._break_graph = ht.core.mark_step

    msg = (f"FusedSDPA slicing is enabled with sequence length threshold {slice_thld}, "
           f"chunk size {self.chunk_size}, num padded query chunks {self.num_padded_query_chunks}, "
           f"num padded ctx chunks {self.num_padded_ctx_chunks}, with graph breaks {self._with_graph_breaks}.")
    logger().debug_once(msg)

    return True

maybe_break_graph

maybe_break_graph()
Source code in vllm_gaudi/extension/utils.py
def maybe_break_graph(self):
    """Insert a graph break between attention chunks when graph breaks are enabled.

    ``_with_graph_breaks`` (and ``_break_graph``) are only assigned by
    ``_setup_slicing`` after all of its early ``return False`` exits, so an
    instance with slicing disabled never gets them. Reading the flag with a
    ``False`` default makes this a no-op in that case instead of raising
    ``AttributeError``.
    """
    if getattr(self, '_with_graph_breaks', False):
        self._break_graph()

Softmax

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class Softmax(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around ``torch.softmax``.

    ``inv_head`` is accepted so the call signature matches other softmax
    modules, but it is ignored here.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, dim=None, inv_head=None):
        """Return ``torch.softmax`` of ``x`` along ``dim``."""
        probs = torch.softmax(x, dim=dim)
        return probs

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    """Initialise the underlying ``torch.nn.Module``; no extra state is created."""
    super().__init__()

forward

forward(x, dim=None, inv_head=None)
Source code in vllm_gaudi/extension/utils.py
def forward(self, x, dim=None, inv_head=None):
    """Return ``torch.softmax`` of ``x`` along ``dim``; ``inv_head`` is ignored."""
    probs = torch.softmax(x, dim=dim)
    return probs

VLLMFP8KVCache

Bases: VLLMKVCache

Source code in vllm_gaudi/extension/utils.py
class VLLMFP8KVCache(VLLMKVCache):
    """FP8 variant of ``VLLMKVCache``.

    Inputs are cast to ``torch.float8_e4m3fn`` with ``input_scale`` before
    being written. Fetched data is cast back to ``torch.bfloat16`` with
    ``output_scale`` (the reciprocal of ``input_scale``).
    """

    def __init__(self, input_scale=1.0):
        super().__init__()
        self.use_contiguous_pa = get_config().use_contiguous_pa
        self.input_scale = input_scale
        self.output_scale = 1.0 / self.input_scale

    def quant_input(self, input):
        """Cast ``input`` to ``torch.float8_e4m3fn`` using ``self.input_scale``."""
        return torch.ops.hpu.cast_to_fp8_v2(input, self.input_scale, False, False, torch.float8_e4m3fn)[0]

    def dequant_output(self, output):
        """Cast FP8 ``output`` back to ``torch.bfloat16`` using ``self.output_scale``."""
        return torch.ops.hpu.cast_from_fp8(output, self.output_scale, torch.bfloat16)

    def forward(self, input, *args, **kwargs):
        """Quantize ``input`` and delegate the cache write to ``VLLMKVCache.forward``.

        The base class treats ``input is None`` as "nothing to store" (for
        example, cross-attention decode). ``None`` is therefore forwarded
        as-is instead of being handed to the FP8 cast.
        """
        qinput = None if input is None else self.quant_input(input)
        return super().forward(qinput, *args, **kwargs)

    def fetch_from_cache(self, quant_cache, blocks, permutations=None, **kwargs):
        """Fetch blocks through the base class and dequantize them to bfloat16.

        ``permutations`` (when truthy) is forwarded by keyword. The base
        ``fetch_from_cache`` has ``scales`` as its third positional parameter,
        so a positional pass would bind it to ``scales``.

        If the base returns a list/tuple, each element is dequantized into a
        new list. Otherwise the single returned tensor is dequantized. The
        fetched data is never written back: with contiguous PA the base
        returns a slice of the cache itself, so item assignment would modify
        the cache.

        ``kwargs`` are accepted for interface compatibility and not forwarded.
        """
        if permutations:
            output_cache = super().fetch_from_cache(quant_cache, blocks, permutations=permutations)
        else:
            output_cache = super().fetch_from_cache(quant_cache, blocks)
        if isinstance(output_cache, (list, tuple)):
            return [self.dequant_output(part) for part in output_cache]
        return self.dequant_output(output_cache)

input_scale instance-attribute

input_scale = input_scale

output_scale instance-attribute

output_scale = 1.0 / input_scale

use_contiguous_pa instance-attribute

use_contiguous_pa = use_contiguous_pa

__init__

__init__(input_scale=1.0)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, input_scale=1.0):
    """Store the FP8 quantization scale and its reciprocal used for dequantization."""
    super().__init__()
    config = get_config()
    self.use_contiguous_pa = config.use_contiguous_pa
    self.input_scale = input_scale
    # Reciprocal scale handed to cast_from_fp8 when data is fetched.
    self.output_scale = 1.0 / self.input_scale

dequant_output

dequant_output(output)
Source code in vllm_gaudi/extension/utils.py
def dequant_output(self, output):
    """Cast FP8 ``output`` back to ``torch.bfloat16`` using ``self.output_scale``."""
    target_dtype = torch.bfloat16
    return torch.ops.hpu.cast_from_fp8(output, self.output_scale, target_dtype)

fetch_from_cache

fetch_from_cache(
    quant_cache, blocks, permutations=None, **kwargs
)
Source code in vllm_gaudi/extension/utils.py
def fetch_from_cache(self, quant_cache, blocks, permutations=None, **kwargs):
    """Fetch blocks through the base class and dequantize them to bfloat16.

    ``permutations`` (when truthy) is forwarded by keyword. The base
    ``VLLMKVCache.fetch_from_cache`` has ``scales`` as its third positional
    parameter, so a positional pass would bind it to ``scales``.

    If the base returns a list/tuple, each element is dequantized into a new
    list. Otherwise the single returned tensor is dequantized. The fetched
    data is never written back: with contiguous PA the base returns a slice of
    the cache itself, so item assignment would modify the cache.

    ``kwargs`` are accepted for interface compatibility and not forwarded.
    """
    if permutations:
        output_cache = super().fetch_from_cache(quant_cache, blocks, permutations=permutations)
    else:
        output_cache = super().fetch_from_cache(quant_cache, blocks)
    if isinstance(output_cache, (list, tuple)):
        return [self.dequant_output(part) for part in output_cache]
    return self.dequant_output(output_cache)

forward

forward(input, *args, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, *args, **kwargs):
    """Quantize ``input`` and delegate the cache write to ``VLLMKVCache.forward``.

    The base class treats ``input is None`` as "nothing to store" (for example,
    cross-attention decode). ``None`` is therefore forwarded as-is instead of
    being handed to the FP8 cast.
    """
    qinput = None if input is None else self.quant_input(input)
    return super().forward(qinput, *args, **kwargs)

quant_input

quant_input(input)
Source code in vllm_gaudi/extension/utils.py
def quant_input(self, input):
    """Cast ``input`` to ``torch.float8_e4m3fn`` using ``self.input_scale``.

    ``cast_to_fp8_v2`` returns a sequence; only element 0 (presumably the FP8
    tensor) is returned.
    """
    cast_result = torch.ops.hpu.cast_to_fp8_v2(input, self.input_scale, False, False, torch.float8_e4m3fn)
    return cast_result[0]

VLLMKVCache

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class VLLMKVCache(torch.nn.Module):
    """KV-cache module: writes rows into the cache and fetches blocks from it."""

    def __init__(self, is_v_cache: bool = False):
        super().__init__()
        self.use_contiguous_pa = get_config().use_contiguous_pa
        # is_v_cache is used in INC FP8 dynamic quantization to identify V cache
        self.is_v_cache = is_v_cache

    def forward(self, input, cache, slot_mapping, scales=None, block_size=None, is_prompt=False, **kwargs):
        """Copy ``input`` rows into ``cache`` at ``slot_mapping`` (dim 0) in place and return ``cache``."""
        # In cross-attention decode the inputs are None; nothing is stored then.
        if input is None:
            return cache
        cache.index_copy_(0, slot_mapping, input)
        return cache

    def fetch_from_cache(self, cache, blocks, scales=None, **kwargs):
        """Return the cache rows for ``blocks``.

        Without contiguous PA the rows are gathered with ``index_select``.
        With contiguous PA the first ``blocks.size(0)`` rows are sliced.
        """
        if not self.use_contiguous_pa:
            return cache.index_select(0, blocks)
        return cache[:blocks.size(0)]

is_v_cache instance-attribute

is_v_cache = is_v_cache

use_contiguous_pa instance-attribute

use_contiguous_pa = use_contiguous_pa

__init__

__init__(is_v_cache: bool = False)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, is_v_cache: bool = False):
    """Read the contiguous-PA setting from the config and remember the cache kind."""
    super().__init__()
    config = get_config()
    self.use_contiguous_pa = config.use_contiguous_pa
    # is_v_cache is used in INC FP8 dynamic quantization to identify V cache
    self.is_v_cache = is_v_cache

fetch_from_cache

fetch_from_cache(cache, blocks, scales=None, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def fetch_from_cache(self, cache, blocks, scales=None, **kwargs):
    """Return the cache rows for ``blocks``.

    Without contiguous PA the rows are gathered with ``index_select``. With
    contiguous PA the first ``blocks.size(0)`` rows are sliced.
    """
    if not self.use_contiguous_pa:
        return cache.index_select(0, blocks)
    num_blocks = blocks.size(0)
    return cache[:num_blocks]

forward

forward(
    input,
    cache,
    slot_mapping,
    scales=None,
    block_size=None,
    is_prompt=False,
    **kwargs,
)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, cache, slot_mapping, scales=None, block_size=None, is_prompt=False, **kwargs):
    """Copy ``input`` rows into ``cache`` at ``slot_mapping`` (dim 0) in place and return ``cache``."""
    # In cross-attention decode the inputs are None; nothing is stored then.
    if input is None:
        return cache
    cache.index_copy_(0, slot_mapping, input)
    return cache

align_and_pad

align_and_pad(data, bucketing, padding_gen)
Source code in vllm_gaudi/extension/utils.py
def align_and_pad(data, bucketing, padding_gen):
    """Pad a batch of rows to the ``(target_bs, target_len)`` bucket shape.

    Args:
        data: List of rows (lists). Rows are extended in place by ``pad_list``.
        bucketing: ``(target_bs, target_len)`` pair.
        padding_gen: Iterator that supplies the padding values.

    If ``target_bs == 1`` and there are several rows, they are first
    concatenated into a single row. Every row is then padded to
    ``target_len``, and ``target_bs - len(rows)`` padding rows are appended.
    Those padding rows are ``itertools.tee`` copies of one
    ``target_len``-long slice of ``padding_gen``, so they are iterators, not
    lists. ``itertools.tee`` raises ``ValueError`` if there are more rows
    than ``target_bs``.
    """
    batch_size = len(data)
    target_bs, target_len = bucketing
    if target_bs == 1 and batch_size > 1:
        data = [list(itertools.chain.from_iterable(data))]
    padded_rows = [pad_list(row, target_len, padding_gen) for row in data]
    pad_row = itertools.islice(padding_gen, target_len)
    filler_rows = itertools.tee(pad_row, target_bs - len(padded_rows))
    return pad_list(padded_rows, target_bs, filler_rows)

get_kv_fetch_extra_args

get_kv_fetch_extra_args(**kwargs)
Source code in vllm_gaudi/extension/utils.py
def get_kv_fetch_extra_args(**kwargs):
    """Return the keyword arguments for a KV fetch.

    ``scales`` is dropped unless the config enables per-token KV scaling.
    """
    if get_config().per_token_kv_scaling_support:
        return kwargs
    kwargs.pop('scales', None)
    return kwargs

is_fake_hpu cached

is_fake_hpu() -> bool
Source code in vllm_gaudi/extension/utils.py
@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
    """Return whether ``VLLM_USE_FAKE_HPU`` requests a fake HPU.

    Unset means ``'0'`` (disabled). The value is stripped and lower-cased, and
    the false-like strings ``0``, ``false``, ``f``, ``no``, ``n`` and ``off``
    disable it; any other value enables it. Previously only the exact string
    ``'0'`` disabled it, so ``VLLM_USE_FAKE_HPU=false`` turned it on.

    The result is cached for the process lifetime (``lru_cache``), so later
    environment changes are not observed.
    """
    value = os.environ.get('VLLM_USE_FAKE_HPU', '0')
    return value.strip().lower() not in ('0', 'false', 'f', 'no', 'n', 'off')

pad_list

pad_list(input, target_len, val_generator)
Source code in vllm_gaudi/extension/utils.py
def pad_list(input, target_len, val_generator):
    """Extend ``input`` in place to ``target_len`` items drawn from ``val_generator``.

    A list that is already at or beyond ``target_len`` is left untouched (it is
    never truncated). The same list object is returned in every case.
    """
    shortfall = target_len - len(input)
    if shortfall <= 0:
        return input
    input.extend(itertools.islice(val_generator, shortfall))
    return input

with_default

with_default(value: Optional[Any], default: Any) -> Any
Source code in vllm_gaudi/extension/utils.py
def with_default(value: Optional[Any], default: Any) -> Any:
    """Return ``value`` unless it is ``None``, in which case return ``default``.

    Falsy values other than ``None`` (``0``, ``""``, ``[]``) are kept.
    """
    return default if value is None else value