Skip to content

vllm_gaudi.extension.utils

B2BMatmul

Bases: Matmul

Specialized alias for batch2block and block2batch matmul operations.

This class is functionally identical to Matmul; it exists only to mark B2B-related matmuls semantically. The marker lets the system apply the fix that uses the B2B output measurements as the input measurements during calibration, which avoids corrupted scales from the KV-cache.

Source code in vllm_gaudi/extension/utils.py
class B2BMatmul(Matmul):
    """Specialized alias for batch2block and block2batch matmul operations.

    This class remains functionally identical to ``Matmul`` but is used to
    semantically mark B2B-related matmuls. This enables the system to apply the
    fix that uses the B2B output measurements as the input measurements during
    calibration, avoiding corrupted scales from the KV-cache.
    """

    def __init__(self):
        # No extra state is added: construction is delegated unchanged to
        # the parent class.
        super().__init__()

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    """Initialise the module by delegating to the parent constructor."""
    # No extra state is added here.
    super().__init__()

FP8Matmul

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class FP8Matmul(torch.nn.Module):
    """Matrix multiplication carried out through the HPU FP8 custom ops.

    Both operands are cast to ``torch.float8_e4m3fn`` with their own scale and
    multiplied with ``torch.ops.hpu.fp8_gemm_v2``; the result is requested as
    ``torch.bfloat16``.
    """

    def __init__(
        self,
        scale_input=1.0,
        scale_other=1.0,
    ):
        """Store the quantization scales of the two operands."""
        super().__init__()
        self.scale_input, self.scale_other = scale_input, scale_other

    def quant_input(self, x, scale):
        """Cast ``x`` to ``torch.float8_e4m3fn`` using ``scale``."""
        # Only the first element of the op's result is used.
        cast_result = torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)
        return cast_result[0]

    def matmul_fp8(self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None):
        """Call ``fp8_gemm_v2`` on two already quantized operands.

        Neither operand is transposed and no bias, output buffer or
        accumulation is requested.
        """
        gemm = torch.ops.hpu.fp8_gemm_v2
        return gemm(
            A=x,
            trans_A=False,
            B=other,
            trans_B=False,
            D=None,
            out_dtype=out_dtype,
            A_scale_inv=scale_input_inv,
            B_scale_inv=scale_other_inv,
            bias=None,
            accumulate=False,
        )

    def forward(self, input, other, **kwargs):
        """Quantize both operands and multiply them.

        Extra keyword arguments are accepted but not used.
        """
        lhs_scale, rhs_scale = self.scale_input, self.scale_other
        return self.matmul_fp8(
            self.quant_input(input, lhs_scale),
            self.quant_input(other, rhs_scale),
            out_dtype=torch.bfloat16,
            # The gemm receives the reciprocal of each quantization scale.
            scale_input_inv=1.0 / lhs_scale,
            scale_other_inv=1.0 / rhs_scale,
        )

scale_input instance-attribute

scale_input = scale_input

scale_other instance-attribute

scale_other = scale_other

__init__

__init__(scale_input=1.0, scale_other=1.0)
Source code in vllm_gaudi/extension/utils.py
def __init__(
    self,
    scale_input=1.0,
    scale_other=1.0,
):
    """Store the quantization scales of the two matmul operands."""
    super().__init__()
    self.scale_input, self.scale_other = scale_input, scale_other

forward

forward(input, other, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, other, **kwargs):
    """Quantize both operands and multiply them, requesting bfloat16 output.

    Extra keyword arguments are accepted but not used.
    """
    lhs_scale, rhs_scale = self.scale_input, self.scale_other
    return self.matmul_fp8(
        self.quant_input(input, lhs_scale),
        self.quant_input(other, rhs_scale),
        out_dtype=torch.bfloat16,
        # The gemm receives the reciprocal of each quantization scale.
        scale_input_inv=1.0 / lhs_scale,
        scale_other_inv=1.0 / rhs_scale,
    )

matmul_fp8

matmul_fp8(
    x,
    other,
    out_dtype,
    scale_input_inv=None,
    scale_other_inv=None,
)
Source code in vllm_gaudi/extension/utils.py
def matmul_fp8(self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None):
    """Call ``fp8_gemm_v2`` on two already quantized operands.

    Neither operand is transposed and no bias, output buffer or accumulation
    is requested.
    """
    gemm = torch.ops.hpu.fp8_gemm_v2
    return gemm(
        A=x,
        trans_A=False,
        B=other,
        trans_B=False,
        D=None,
        out_dtype=out_dtype,
        A_scale_inv=scale_input_inv,
        B_scale_inv=scale_other_inv,
        bias=None,
        accumulate=False,
    )

quant_input

quant_input(x, scale)
Source code in vllm_gaudi/extension/utils.py
def quant_input(self, x, scale):
    """Cast ``x`` to ``torch.float8_e4m3fn`` using ``scale``."""
    # Only the first element of the op's result is used.
    cast_result = torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)
    return cast_result[0]

Matmul

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class Matmul(torch.nn.Module):
    """Module wrapper around ``torch.matmul``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, **kwargs):
        """Return ``torch.matmul(x, y)``, forwarding any keyword arguments."""
        product = torch.matmul(x, y, **kwargs)
        return product

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    """Initialise the module; no state beyond the parent's is created."""
    super().__init__()

forward

forward(x, y, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def forward(self, x, y, **kwargs):
    """Return ``torch.matmul(x, y)``, forwarding any keyword arguments."""
    product = torch.matmul(x, y, **kwargs)
    return product

ModuleFP8FusedSDPA

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class ModuleFP8FusedSDPA(torch.nn.Module):
    """Wrap an FP8 fused-SDPA kernel, quantizing query, key and value first.

    All scale attributes start as ``torch.tensor(1.0)`` placeholders.
    """

    def __init__(self, fusedSDPA):
        """Keep the kernel callable and create the placeholder scales."""
        super().__init__()
        assert fusedSDPA is not None, 'FP8 fusedSDPA kernel is None'
        self.fp8_fused_sdpa = fusedSDPA

        # Every scale is set to 1.0 temporarily; each attribute gets its own
        # tensor object.
        for attr_name in (
                "descale_amax",
                "scale_amax",
                "scale_q",
                "scale_k",
                "scale_v",
                "d_scale_q",
                "d_scale_k",
                "d_scale_v",
        ):
            setattr(self, attr_name, torch.tensor(1.0))

    def quant_input(self, x, scale):
        """Cast ``x`` to ``torch.float8_e4m3fn`` using ``scale``."""
        cast_result = torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)
        return cast_result[0]

    def dequant_output(self, output, scale):
        """Cast an FP8 tensor back, requesting ``torch.bfloat16``.

        Not called by ``forward`` in this class.
        """
        restored = torch.ops.hpu.cast_from_fp8(output, scale, torch.bfloat16)
        return restored

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",
        window_size=None,
    ):
        """Quantize q/k/v, run the FP8 kernel and return its first result.

        ``recompute_mode`` and ``window_size`` are accepted but are not
        passed to the kernel.
        """
        fp8_query = self.quant_input(query, self.scale_q)
        fp8_key = self.quant_input(key, self.scale_k)
        fp8_value = self.quant_input(value, self.scale_v)

        # NOTE: q_scale_o is not passed to the kernel here.
        kernel_results = self.fp8_fused_sdpa(
            fp8_query,
            fp8_key,
            fp8_value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            softmax_mode=softmax_mode,
            d_scale_q=self.d_scale_q,
            d_scale_k=self.d_scale_k,
            d_scale_v=self.d_scale_v,
            q_scale_s=self.scale_amax,
            d_scale_s=self.descale_amax,
            is_amax_s=False,
            valid_seq_len=valid_sequence_lengths,
            seq_padding_type=padding_side,
        )
        return kernel_results[0]

d_scale_k instance-attribute

d_scale_k = tensor(1.0)

d_scale_q instance-attribute

d_scale_q = tensor(1.0)

d_scale_v instance-attribute

d_scale_v = tensor(1.0)

descale_amax instance-attribute

descale_amax = tensor(1.0)

fp8_fused_sdpa instance-attribute

fp8_fused_sdpa = fusedSDPA

scale_amax instance-attribute

scale_amax = tensor(1.0)

scale_k instance-attribute

scale_k = tensor(1.0)

scale_q instance-attribute

scale_q = tensor(1.0)

scale_v instance-attribute

scale_v = tensor(1.0)

__init__

__init__(fusedSDPA)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, fusedSDPA):
    """Keep the FP8 kernel callable and create placeholder scale tensors."""
    super().__init__()
    assert fusedSDPA is not None, 'FP8 fusedSDPA kernel is None'
    self.fp8_fused_sdpa = fusedSDPA

    # Every scale is set to 1.0 temporarily; each attribute gets its own
    # tensor object.
    for attr_name in (
            "descale_amax",
            "scale_amax",
            "scale_q",
            "scale_k",
            "scale_v",
            "d_scale_q",
            "d_scale_k",
            "d_scale_v",
    ):
        setattr(self, attr_name, torch.tensor(1.0))

dequant_output

dequant_output(output, scale)
Source code in vllm_gaudi/extension/utils.py
def dequant_output(self, output, scale):
    """Cast an FP8 tensor back with ``scale``, requesting ``torch.bfloat16``."""
    restored = torch.ops.hpu.cast_from_fp8(output, scale, torch.bfloat16)
    return restored

forward

forward(
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
)
Source code in vllm_gaudi/extension/utils.py
def forward(
    self,
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
):
    """Quantize q/k/v, run the FP8 kernel and return its first result.

    ``recompute_mode`` and ``window_size`` are accepted but are not passed to
    the kernel.
    """
    fp8_query = self.quant_input(query, self.scale_q)
    fp8_key = self.quant_input(key, self.scale_k)
    fp8_value = self.quant_input(value, self.scale_v)

    # NOTE: q_scale_o is not passed to the kernel here.
    kernel_results = self.fp8_fused_sdpa(
        fp8_query,
        fp8_key,
        fp8_value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        softmax_mode=softmax_mode,
        d_scale_q=self.d_scale_q,
        d_scale_k=self.d_scale_k,
        d_scale_v=self.d_scale_v,
        q_scale_s=self.scale_amax,
        d_scale_s=self.descale_amax,
        is_amax_s=False,
        valid_seq_len=valid_sequence_lengths,
        seq_padding_type=padding_side,
    )
    return kernel_results[0]

quant_input

quant_input(x, scale)
Source code in vllm_gaudi/extension/utils.py
def quant_input(self, x, scale):
    """Cast ``x`` to ``torch.float8_e4m3fn`` using ``scale``."""
    # Only the first element of the op's result is used.
    cast_result = torch.ops.hpu.cast_to_fp8_v2(x, scale, False, False, torch.float8_e4m3fn)
    return cast_result[0]

ModuleFusedSDPA

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class ModuleFusedSDPA(torch.nn.Module):
    """Module wrapper that forwards to a fused-SDPA kernel's ``apply``."""

    def __init__(self, fusedSDPA):
        """Keep a reference to the kernel object, which must not be ``None``."""
        super().__init__()
        assert fusedSDPA is not None, 'fusedSDPA kernel is None'
        self._hpu_kernel_fsdpa = fusedSDPA

    def forward(
        self,
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",
        window_size=None,
        sinks=None,
    ):
        """Invoke the kernel positionally and return whatever it returns.

        When ``window_size`` is ``None`` the kernel receives ``(-1, -1)``
        in its place.
        """
        effective_window = (-1, -1) if window_size is None else window_size
        # The two literal ``False`` values are fixed positional flags of the
        # kernel call; their meaning is not visible from this module.
        return self._hpu_kernel_fsdpa.apply(
            query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
            recompute_mode, valid_sequence_lengths, padding_side, False, False,
            effective_window, sinks,
        )

_hpu_kernel_fsdpa instance-attribute

_hpu_kernel_fsdpa = fusedSDPA

__init__

__init__(fusedSDPA)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, fusedSDPA):
    """Keep a reference to the kernel object, which must not be ``None``."""
    super().__init__()
    assert fusedSDPA is not None, 'fusedSDPA kernel is None'
    self._hpu_kernel_fsdpa = fusedSDPA

forward

forward(
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
    sinks=None,
)
Source code in vllm_gaudi/extension/utils.py
def forward(
    self,
    query,
    key,
    value,
    attn_mask,
    dropout_p,
    is_causal,
    scale,
    softmax_mode,
    recompute_mode,
    valid_sequence_lengths,
    padding_side="left",
    window_size=None,
    sinks=None,
):
    """Invoke the kernel's ``apply`` positionally and return its result.

    When ``window_size`` is ``None`` the kernel receives ``(-1, -1)`` in its
    place.
    """
    effective_window = (-1, -1) if window_size is None else window_size
    # The two literal ``False`` values are fixed positional flags of the
    # kernel call; their meaning is not visible from this function.
    return self._hpu_kernel_fsdpa.apply(
        query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
        recompute_mode, valid_sequence_lengths, padding_side, False, False,
        effective_window, sinks,
    )

Softmax

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class Softmax(torch.nn.Module):
    """Module wrapper around ``torch.softmax``."""

    def __init__(self):
        """Initialise the module; no state beyond the parent's is created."""
        super().__init__()

    def forward(self, x, dim=None, inv_head=None):
        """Return ``torch.softmax(x, dim)``.

        ``inv_head`` is accepted but not used.
        """
        probabilities = torch.softmax(x, dim)
        return probabilities

__init__

__init__()
Source code in vllm_gaudi/extension/utils.py
def __init__(self):
    """Initialise the module by delegating to the parent constructor."""
    # No extra state is added here.
    super().__init__()

forward

forward(x, dim=None, inv_head=None)
Source code in vllm_gaudi/extension/utils.py
def forward(self, x, dim=None, inv_head=None):
    """Return ``torch.softmax(x, dim)``; ``inv_head`` is accepted but unused."""
    probabilities = torch.softmax(x, dim)
    return probabilities

VLLMFP8KVCache

Bases: VLLMKVCache

Source code in vllm_gaudi/extension/utils.py
class VLLMFP8KVCache(VLLMKVCache):
    """KV cache variant that quantizes to FP8 on store and casts back on fetch.

    Inputs are cast to ``torch.float8_e4m3fn`` with ``input_scale`` before
    being handed to ``VLLMKVCache.forward``; fetched data is cast back with
    ``output_scale`` (the reciprocal of ``input_scale``), requesting
    ``torch.bfloat16``.
    """

    def __init__(self, input_scale=1.0):
        # ``is_v_cache`` is not forwarded, so the parent keeps its default.
        super().__init__()
        # Read from the global config (the parent constructor shown in this
        # file assigns the same attribute).
        self.use_contiguous_pa = get_config().use_contiguous_pa
        self.input_scale = input_scale
        # Reciprocal of the quantization scale, used when casting back.
        self.output_scale = 1.0 / self.input_scale

    def quant_input(self, input):
        """Cast ``input`` to ``torch.float8_e4m3fn`` using ``self.input_scale``."""
        # Only the first element of the op's result is used.
        return torch.ops.hpu.cast_to_fp8_v2(input, self.input_scale, False, False, torch.float8_e4m3fn)[0]

    def dequant_output(self, output):
        """Cast ``output`` back from FP8 with ``self.output_scale``, requesting bfloat16."""
        return torch.ops.hpu.cast_from_fp8(output, self.output_scale, torch.bfloat16)

    def forward(self, input, *args, **kwargs):
        """Quantize ``input`` and delegate the store to ``VLLMKVCache.forward``."""
        # NOTE(review): ``input`` is quantized unconditionally, while the parent
        # forward explicitly tolerates ``input is None`` - confirm that case
        # cannot reach this subclass.
        qinput = self.quant_input(input)
        return super().forward(qinput, *args, **kwargs)

    def fetch_from_cache(self, quant_cache, blocks, permutations=None, **kwargs):
        """Fetch from the quantized cache and cast the result back from FP8.

        Extra keyword arguments are accepted but not forwarded to the parent.
        """
        # NOTE(review): the parent signature shown in this file is
        # ``fetch_from_cache(cache, blocks, scales=None, **kwargs)``, so the
        # third positional argument below is received as ``scales``, not as
        # permutations. This branch looks like a leftover from an older parent
        # API - confirm whether it is still reachable.
        if permutations:
            output_cache = super().fetch_from_cache(quant_cache, blocks, permutations)
            # Each element is replaced by its dequantized counterpart.
            # NOTE(review): if the parent returns a tensor (as the version in
            # this file does) this is an in-place item assignment into that
            # tensor - verify this is intended.
            for i in range(len(output_cache)):
                output_cache[i] = self.dequant_output(output_cache[i])
            return output_cache
        output_cache = super().fetch_from_cache(quant_cache, blocks)
        return self.dequant_output(output_cache)

input_scale instance-attribute

input_scale = input_scale

output_scale instance-attribute

output_scale = 1.0 / input_scale

use_contiguous_pa instance-attribute

use_contiguous_pa = use_contiguous_pa

__init__

__init__(input_scale=1.0)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, input_scale=1.0):
    """Store the quantization scale and its reciprocal.

    ``is_v_cache`` is not forwarded to the parent constructor.
    """
    super().__init__()
    config = get_config()
    self.use_contiguous_pa = config.use_contiguous_pa
    self.input_scale = input_scale
    # Reciprocal of the quantization scale, used when casting back from FP8.
    self.output_scale = 1.0 / self.input_scale

dequant_output

dequant_output(output)
Source code in vllm_gaudi/extension/utils.py
def dequant_output(self, output):
    """Cast ``output`` back from FP8 with ``self.output_scale``, requesting bfloat16."""
    restored = torch.ops.hpu.cast_from_fp8(output, self.output_scale, torch.bfloat16)
    return restored

fetch_from_cache

fetch_from_cache(
    quant_cache, blocks, permutations=None, **kwargs
)
Source code in vllm_gaudi/extension/utils.py
def fetch_from_cache(self, quant_cache, blocks, permutations=None, **kwargs):
    """Fetch from the quantized cache and cast the result back from FP8.

    Extra keyword arguments are accepted but not forwarded to the parent.
    """
    # NOTE(review): the ``VLLMKVCache.fetch_from_cache`` shown in this file has
    # the signature ``(cache, blocks, scales=None, **kwargs)``, so the third
    # positional argument below is received as ``scales``, not as
    # permutations. This branch looks like a leftover from an older parent
    # API - confirm whether it is still reachable.
    if permutations:
        output_cache = super().fetch_from_cache(quant_cache, blocks, permutations)
        # Each element is replaced by its dequantized counterpart.
        # NOTE(review): if the parent returns a tensor this is an in-place
        # item assignment into that tensor - verify this is intended.
        for i in range(len(output_cache)):
            output_cache[i] = self.dequant_output(output_cache[i])
        return output_cache
    output_cache = super().fetch_from_cache(quant_cache, blocks)
    return self.dequant_output(output_cache)

forward

forward(input, *args, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, *args, **kwargs):
    """Quantize ``input`` and delegate the store to the parent ``forward``."""
    return super().forward(self.quant_input(input), *args, **kwargs)

quant_input

quant_input(input)
Source code in vllm_gaudi/extension/utils.py
def quant_input(self, input):
    """Cast ``input`` to ``torch.float8_e4m3fn`` using ``self.input_scale``."""
    # Only the first element of the op's result is used.
    cast_result = torch.ops.hpu.cast_to_fp8_v2(input, self.input_scale, False, False, torch.float8_e4m3fn)
    return cast_result[0]

VLLMKVCache

Bases: Module

Source code in vllm_gaudi/extension/utils.py
class VLLMKVCache(torch.nn.Module):
    """Module that writes entries into a KV cache tensor and reads blocks back."""

    def __init__(self, is_v_cache: bool = False):
        """Read ``use_contiguous_pa`` from the config and remember ``is_v_cache``."""
        super().__init__()
        config = get_config()
        self.use_contiguous_pa = config.use_contiguous_pa
        # is_v_cache is used in INC FP8 dynamic quantization to identify V cache
        self.is_v_cache = is_v_cache

    def forward(self, input, cache, slot_mapping, scales=None, block_size=None, is_prompt=False, **kwargs):
        """Copy ``input`` into ``cache`` at ``slot_mapping`` along dim 0.

        The cache is modified in place and returned. ``scales``,
        ``block_size``, ``is_prompt`` and extra keyword arguments are not used.
        """
        # In cross-attention kv cache forward inputs are None in decode;
        # nothing is stored in the cache in that case.
        if input is None:
            return cache
        cache.index_copy_(0, slot_mapping, input)
        return cache

    def fetch_from_cache(self, cache, blocks, scales=None, **kwargs):
        """Return the cache entries for ``blocks``.

        With ``use_contiguous_pa`` the first ``blocks.size(0)`` entries are
        sliced; otherwise ``blocks`` is used as an index along dim 0.
        ``scales`` and extra keyword arguments are not used.
        """
        if not self.use_contiguous_pa:
            return cache.index_select(0, blocks)
        return cache[:blocks.size(0)]

is_v_cache instance-attribute

is_v_cache = is_v_cache

use_contiguous_pa instance-attribute

use_contiguous_pa = use_contiguous_pa

__init__

__init__(is_v_cache: bool = False)
Source code in vllm_gaudi/extension/utils.py
def __init__(self, is_v_cache: bool = False):
    """Read ``use_contiguous_pa`` from the config and remember ``is_v_cache``."""
    super().__init__()
    config = get_config()
    self.use_contiguous_pa = config.use_contiguous_pa
    # is_v_cache is used in INC FP8 dynamic quantization to identify V cache
    self.is_v_cache = is_v_cache

fetch_from_cache

fetch_from_cache(cache, blocks, scales=None, **kwargs)
Source code in vllm_gaudi/extension/utils.py
def fetch_from_cache(self, cache, blocks, scales=None, **kwargs):
    """Return the cache entries for ``blocks``.

    With ``use_contiguous_pa`` the first ``blocks.size(0)`` entries are sliced;
    otherwise ``blocks`` is used as an index along dim 0. ``scales`` and extra
    keyword arguments are not used.
    """
    if not self.use_contiguous_pa:
        return cache.index_select(0, blocks)
    return cache[:blocks.size(0)]

forward

forward(
    input,
    cache,
    slot_mapping,
    scales=None,
    block_size=None,
    is_prompt=False,
    **kwargs,
)
Source code in vllm_gaudi/extension/utils.py
def forward(self, input, cache, slot_mapping, scales=None, block_size=None, is_prompt=False, **kwargs):
    """Copy ``input`` into ``cache`` at ``slot_mapping`` along dim 0.

    The cache is modified in place and returned. ``scales``, ``block_size``,
    ``is_prompt`` and extra keyword arguments are not used.
    """
    # In cross-attention kv cache forward inputs are None in decode;
    # nothing is stored in the cache in that case.
    if input is None:
        return cache
    cache.index_copy_(0, slot_mapping, input)
    return cache

align_and_pad

align_and_pad(data, bucketing, padding_gen)
Source code in vllm_gaudi/extension/utils.py
def align_and_pad(data, bucketing, padding_gen):
    """Pad a batch of sequences up to a ``(batch_size, seq_len)`` bucket.

    Args:
        data: List of lists, one per sequence. Rows shorter than the target
            length are extended in place.
        bucketing: ``(target_bs, target_len)`` pair.
        padding_gen: Iterator that yields padding values (for example
            ``itertools.repeat(0)``). Values are drawn from it as needed.

    Returns:
        A new outer list with ``target_bs`` rows. Every row is a list; rows
        that were already longer than ``target_len`` are left as they are.
        When ``target_bs`` is 1 and several sequences are given, they are
        concatenated into a single row before padding.

    Raises:
        ValueError: If there are more rows than ``target_bs`` (and they were
            not merged into one row).
    """
    target_bs, target_len = bucketing
    if target_bs == 1 and len(data) > 1:
        # A single-row bucket receives all sequences concatenated together.
        data = [list(itertools.chain.from_iterable(data))]

    rows = []
    for row in data:
        missing = target_len - len(row)
        if missing > 0:
            row.extend(itertools.islice(padding_gen, missing))
        rows.append(row)

    missing_rows = target_bs - len(rows)
    if missing_rows < 0:
        raise ValueError(f"batch of {len(rows)} rows does not fit target batch size {target_bs}")
    if missing_rows:
        # Draw one row worth of padding and give every padding row its own
        # list copy, so all rows have the same type and none share state.
        pad_row = list(itertools.islice(padding_gen, target_len))
        rows.extend(list(pad_row) for _ in range(missing_rows))
    return rows

get_kv_fetch_extra_args

get_kv_fetch_extra_args(**kwargs)
Source code in vllm_gaudi/extension/utils.py
def get_kv_fetch_extra_args(**kwargs):
    """Return the keyword arguments, dropping ``scales`` when unsupported.

    ``scales`` is removed (if present) whenever the config's
    ``per_token_kv_scaling_support`` flag is falsy.
    """
    supports_scales = get_config().per_token_kv_scaling_support
    if supports_scales:
        return kwargs
    kwargs.pop('scales', None)
    return kwargs

is_fake_hpu cached

is_fake_hpu() -> bool
Source code in vllm_gaudi/extension/utils.py
@lru_cache(maxsize=None)
def is_fake_hpu() -> bool:
    """Tell whether ``VLLM_USE_FAKE_HPU`` is set to anything other than ``'0'``.

    The environment is read on the first call only; the result is cached.
    """
    flag = os.environ.get('VLLM_USE_FAKE_HPU', '0')
    return flag != '0'

pad_list

pad_list(input, target_len, val_generator)
Source code in vllm_gaudi/extension/utils.py
def pad_list(input, target_len, val_generator):
    """Extend ``input`` in place to ``target_len`` items and return it.

    Missing items are drawn from ``val_generator``. A list that is already
    long enough is returned untouched and nothing is drawn.
    """
    missing = target_len - len(input)
    if missing <= 0:
        return input
    input.extend(itertools.islice(val_generator, missing))
    return input

with_default

with_default(value: Optional[Any], default: Any) -> Any
Source code in vllm_gaudi/extension/utils.py
def with_default(value: Optional[Any], default: Any) -> Any:
    """Return ``value`` unless it is ``None``, in which case return ``default``."""
    return default if value is None else value