Skip to content

vllm_gaudi.attention.backends.hpu_attn

logger module-attribute

logger = logger()

HPUAttentionBackend

Bases: AttentionBackend

Source code in vllm_gaudi/attention/backends/hpu_attn.py
class HPUAttentionBackend(AttentionBackend):
    """Attention backend whose KV-cache helpers delegate to ``HPUPagedAttention``.

    ``get_name``, ``get_impl_cls`` and ``get_metadata_cls`` are not provided
    here and unconditionally raise ``NotImplementedError``.
    """

    @staticmethod
    def get_name() -> str:
        """Not provided by this class; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Not provided by this class; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Not provided by this class; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def get_builder_cls() -> type[HPUPagedAttentionMetadataBuilder]:
        """Return the metadata builder class used with this backend."""
        builder_cls = HPUPagedAttentionMetadataBuilder
        return builder_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV-cache shape as computed by ``HPUPagedAttention``."""
        cache_shape = HPUPagedAttention.get_kv_cache_shape(
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
        )
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dsts: torch.Tensor,
    ) -> None:
        """Forward a block swap request to ``HPUPagedAttention.swap_blocks``."""
        HPUPagedAttention.swap_blocks(
            src_kv_cache,
            dst_kv_cache,
            src_to_dsts,
        )
        return None

    @staticmethod
    def copy_blocks(
        kv_caches: list[torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        """Forward a block copy request to ``HPUPagedAttention.copy_blocks``."""
        HPUPagedAttention.copy_blocks(
            kv_caches,
            src_to_dsts,
        )
        return None

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block sizes this backend reports (only 128)."""
        supported_sizes = [128]
        return supported_sizes

copy_blocks staticmethod

copy_blocks(
    kv_caches: list[Tensor], src_to_dsts: Tensor
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def copy_blocks(
    kv_caches: list[torch.Tensor],
    src_to_dsts: torch.Tensor,
) -> None:
    """Forward a block copy request to ``HPUPagedAttention.copy_blocks``."""
    HPUPagedAttention.copy_blocks(
        kv_caches,
        src_to_dsts,
    )
    return None

get_builder_cls staticmethod

get_builder_cls() -> type[HPUPagedAttentionMetadataBuilder]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_builder_cls() -> type[HPUPagedAttentionMetadataBuilder]:
    """Return the metadata builder class used with this backend."""
    builder_cls = HPUPagedAttentionMetadataBuilder
    return builder_cls

get_impl_cls staticmethod

get_impl_cls() -> type[AttentionImpl]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
    raise NotImplementedError()

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    """Return the KV-cache shape as computed by ``HPUPagedAttention``."""
    cache_shape = HPUPagedAttention.get_kv_cache_shape(
        num_blocks,
        block_size,
        num_kv_heads,
        head_size,
    )
    return cache_shape

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    raise NotImplementedError()

get_name staticmethod

get_name() -> str
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_name() -> str:
    raise NotImplementedError()

get_supported_kernel_block_sizes staticmethod

get_supported_kernel_block_sizes() -> list[
    int | MultipleOf
]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
    """Return the kernel block sizes this backend reports (only 128)."""
    supported_sizes = [128]
    return supported_sizes

swap_blocks staticmethod

swap_blocks(
    src_kv_cache: Tensor,
    dst_kv_cache: Tensor,
    src_to_dsts: Tensor,
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def swap_blocks(
    src_kv_cache: torch.Tensor,
    dst_kv_cache: torch.Tensor,
    src_to_dsts: torch.Tensor,
) -> None:
    """Forward a block swap request to ``HPUPagedAttention.swap_blocks``."""
    HPUPagedAttention.swap_blocks(
        src_kv_cache,
        dst_kv_cache,
        src_to_dsts,
    )
    return None

HPUAttentionImpl

Bases: AttentionImpl, Module

If the input tensors contain prompt tokens, the layout is as follows:

    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

Otherwise, the layout is as follows:

    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding.

The prompts might have different lengths, while the generation tokens always have length 1.

Source code in vllm_gaudi/attention/backends/hpu_attn.py
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Configure the attention operators and validate the layer settings.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be one of
                ``HPUPagedAttention.get_supported_head_sizes()``.
            scale: Attention scale, stored as a Python float.
            num_kv_heads: Number of KV heads; ``None`` falls back to
                ``num_heads``. Must divide ``num_heads``.
            alibi_slopes: Optional ALiBi slopes, converted to a tensor.
            sliding_window: Optional sliding-window size.
            kv_cache_dtype: KV-cache dtype name; ``'fp8_inc'`` selects the
                FP8 operators when ``QUANT_CONFIG`` is not set in the
                environment.
            logits_soft_cap: Accepted for interface compatibility; not used
                in this method.
            attn_type: One of the ``AttentionType`` values accepted below.
            kv_sharing_target_layer_name: When set, ``forward`` does not
                write new keys/values into the cache.
            use_irope: Only triggers a warning; irope is not supported.
            sinks: Optional tensor whose first dimension must equal
                ``num_heads``.

        Raises:
            ValueError: If ``head_size`` is not supported.
            NotImplementedError: If ``attn_type`` is not supported.
        """
        # Two-argument super(): __init__ is resolved from the classes that
        # follow AttentionImpl in the MRO (torch.nn.Module is the other
        # declared base), i.e. AttentionImpl.__init__ itself is skipped.
        super(AttentionImpl, self).__init__()
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        if kv_sharing_target_layer_name is not None:
            logger.info("[KV sharing] HPUAttentionImpl initialized with kv_sharing_target_layer_name: %s",
                        self.kv_sharing_target_layer_name)
        if use_irope:
            logger.warning_once("Using irope in HPU is not supported yet, it will fall back "
                                "to global attention for long context.")
        # FP8 operators are only used for 'fp8_inc' caches when no
        # QUANT_CONFIG environment variable is present.
        self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.matmul_qk = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.batch2block_matmul = B2BMatmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.block2batch_matmul = B2BMatmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.k_cache = VLLMKVCache() if not self.enable_fp8_attn \
            else VLLMFP8KVCache()
        self.v_cache = VLLMKVCache(is_v_cache=True) if not self.enable_fp8_attn \
            else VLLMFP8KVCache()
        HPUFusedSDPA = kernels.fsdpa()
        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
            else ModuleFusedSDPA(HPUFusedSDPA)
        self.prefill_impl = get_config().prompt_attn_impl
        self.use_contiguous_pa = get_config().use_contiguous_pa
        self.use_merged_prefill = get_config().merged_prefill
        if alibi_slopes is not None:
            assert self.prefill_impl != 'flex_impl', \
                'Prefill with Flex Attention not supported with alibi slopes!'
            assert self.prefill_impl != 'fsdpa_impl', \
                'Prefill with FusedSDPA not supported with alibi slopes!'
            assert self.use_contiguous_pa, \
                'Non-contiguous PA not supported with alibi slopes!'

        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        self.prompt_position_bias = None
        self.prev_attn = None
        self.alibi_slopes = None
        if alibi_slopes is not None:
            slope_tensor_dtype = torch.float32 if \
                get_config().fp32_alibi_biases else torch.bfloat16
            alibi_slopes_tensor = torch.tensor(alibi_slopes, dtype=slope_tensor_dtype)
            self.alibi_slopes = alibi_slopes_tensor

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(f"Head size {head_size} is not supported by PagedAttention. "
                             f"Supported head sizes are: {supported_head_sizes}.")

        self.attn_type = attn_type
        if self.attn_type not in (AttentionType.DECODER, AttentionType.ENCODER_DECODER,
                                  AttentionType.ENCODER_ONLY):
            raise NotImplementedError("Encoder self-attention "
                                      "is not implemented for "
                                      "HPUAttentionImpl")
        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
                                                 f"num_heads: {num_heads}.")

        self.is_chunked_attention = False

    def _maybe_init_alibi_biases(
        self,
        max_seq_len,
        prev_attn: Optional[torch.nn.Module] = None,
    ) -> None:
        """Record ``max_seq_len`` and, when ALiBi is enabled, set up the prompt bias.

        Args:
            max_seq_len: Sequence length used to pre-compute the prompt bias.
            prev_attn: Optional module whose ``impl`` attribute is stored as
                ``self.prev_attn``; when present, its ``alibi_slopes`` and
                ``prompt_position_bias`` are reused instead of recomputed.
        """
        self.max_seq_len = max_seq_len
        self.prev_attn = None if prev_attn is None else prev_attn.impl
        if self.alibi_slopes is not None:
            if self.prev_attn is not None:
                self.alibi_slopes = self.prev_attn.alibi_slopes
                self.prompt_position_bias = self.prev_attn.prompt_position_bias
            else:
                # Creating the prompt_position_bias once and reusing it
                # if seq_len permits.
                self.prompt_position_bias = _make_prompt_alibi_bias(
                    alibi_slopes=self.alibi_slopes,
                    seq_len=self.max_seq_len,
                    dtype=self.alibi_slopes.dtype,
                )

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with PagedAttention.

        Args:
            layer: Attention layer; only its ``_k_scale_float`` and
                ``_v_scale_float`` are read, and only for encoder-decoder
                attention.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE(review): the cache is only used when it is a tuple
                (see the ``isinstance`` check below) - confirm the shape
                description above against the caller.
            attn_metadata: Metadata for attention.
            output: Ignored on input; the name is rebound to the result.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if self.attn_type == AttentionType.ENCODER_DECODER:
            return self.forward_encoder_decoder(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                k_scale=layer._k_scale_float,
                # Fixed: the value scale was previously fed from
                # layer._k_scale_float.
                v_scale=layer._v_scale_float,
            )
        # Set return shape
        output_shape = query.shape
        if query.dim() == 2:
            if attn_metadata.seq_lens_tensor is not None:
                batch_size = attn_metadata.seq_lens_tensor.shape[0] if not self.use_merged_prefill else 1
            else:
                assert attn_metadata.block_mapping is not None, \
                    "seq_lens_tensor must be provided for attention"
                batch_size = attn_metadata.block_mapping.shape[1]
            num_tokens, hidden_size = query.shape
            seq_len = num_tokens // batch_size
            query = query.view(batch_size, seq_len, -1)
        else:
            batch_size, seq_len, hidden_size = query.shape

        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        slot_mapping = attn_metadata.slot_mapping.flatten() if attn_metadata.slot_mapping is not None else None
        key_cache = None
        value_cache = None
        k_scales = None
        v_scales = None
        if kv_cache is not None and isinstance(kv_cache, tuple):
            key_cache, value_cache, k_scales, v_scales = \
                HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)
            # Only float32 inputs are cast down to the cache dtype; the query
            # is then aligned with the (possibly converted) key dtype.
            if key.dtype == torch.float32 and key.dtype != key_cache.dtype:
                key = key.to(key_cache.dtype)
            if value.dtype == torch.float32 and value.dtype != value_cache.dtype:
                value = value.to(value_cache.dtype)
            if query.dtype != key.dtype:
                query = query.to(key.dtype)
            if self.kv_sharing_target_layer_name is None:
                # Reshape the input keys and values and store them in the cache.
                # If kv_cache is not provided, the new key and value tensors are
                # not cached. This happens during the initial memory profiling run.
                key_cache = self.k_cache(key,
                                         key_cache,
                                         slot_mapping,
                                         scales=k_scales,
                                         block_size=attn_metadata.block_size,
                                         is_prompt=attn_metadata.is_prompt)
                value_cache = self.v_cache(value,
                                           value_cache,
                                           slot_mapping,
                                           scales=v_scales,
                                           block_size=attn_metadata.block_size,
                                           is_prompt=attn_metadata.is_prompt)

        if attn_metadata.is_prompt or seq_len > 1:
            # Prompt run.
            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, -1, self.num_kv_heads, self.head_size)

            attn_bias = attn_metadata.attn_bias
            position_bias = None
            # If we have alibi_slopes (and no block list), derive the
            # position bias that is passed alongside attn_bias.
            if (attn_metadata.block_list is None and self.prompt_position_bias is not None
                    and self.alibi_slopes is not None):
                assert attn_bias is not None, \
                        'attn_bias must be set before calling ' \
                        'model.forward with alibi biases'
                slice_1_size = attn_bias.size(-2)
                slice_2_size = attn_bias.size(-1)
                if self.max_seq_len >= max(slice_1_size, slice_2_size):
                    # Using pre-computed prompt_position_bias subset.
                    position_bias = self.prompt_position_bias[:, :, -slice_1_size:, -slice_2_size:]

                else:
                    # For longer sequences than precomputed,
                    # recreate the bias. This is memory inefficient.
                    position_bias = _make_prompt_alibi_bias(
                        alibi_slopes=self.alibi_slopes,
                        seq_len=max(slice_1_size, slice_2_size),
                        dtype=self.alibi_slopes.dtype,
                    )

            # attn_metadata has already been dereferenced above, so the
            # attribute can be read directly (None stays None).
            block_list = attn_metadata.block_list

            common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size,
                                                     k_scales, v_scales)

            if self.sliding_window:
                if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
                    attn_bias = attn_metadata.window_attn_bias
                else:
                    # No precomputed window bias: drop the bias and pass the
                    # window size to the attention op instead.
                    attn_bias = None
                    window_size = (self.sliding_window, 0)
                    common_args['window_size'] = window_size
            if self.is_chunked_attention and \
                hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None:
                attn_bias = attn_metadata.chunked_attn_bias

            out = ops.prompt_attention(impl=self.prefill_impl,
                                       query=query.view(query_shape),
                                       key=key.view(kv_shape),
                                       value=value.view(kv_shape),
                                       is_causal=True,
                                       attn_bias=attn_bias,
                                       position_bias=position_bias,
                                       valid_seq_lengths=attn_metadata.seq_lens_tensor,
                                       **common_args)

            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            if self.sliding_window and \
                attn_metadata.window_block_list is not None:
                block_list = attn_metadata.window_block_list
                block_groups = attn_metadata.window_block_groups
                block_mapping = attn_metadata.window_block_mapping
                attn_bias = attn_metadata.window_attn_bias
            elif self.is_chunked_attention and \
                attn_metadata.chunked_block_list is not None:
                block_list = attn_metadata.chunked_block_list
                block_groups = attn_metadata.chunked_block_groups
                block_mapping = attn_metadata.chunked_block_mapping
                attn_bias = attn_metadata.chunked_attn_bias
            else:
                block_list = attn_metadata.block_list
                block_groups = attn_metadata.block_groups
                block_mapping = attn_metadata.block_mapping
                attn_bias = attn_metadata.attn_bias

            self.position_bias = None
            alibi_blocks = getattr(attn_metadata, 'alibi_blocks', None)
            if self.alibi_slopes is not None and alibi_blocks is not None:
                if self.prev_attn is not None:
                    # Reuse the bias already stored on the previous impl.
                    self.position_bias = self.prev_attn.position_bias
                else:
                    # For decoding, compute position bias using alibi_blocks.
                    self.position_bias = _make_decode_alibi_bias(
                        alibi_blocks=alibi_blocks,
                        alibi_slopes=self.alibi_slopes,
                        dtype=self.alibi_slopes.dtype,
                    )

            # Without a cache there is nothing to attend to; return zeros
            # in the caller's original query shape.
            if key_cache is None:
                return torch.zeros(*output_shape, dtype=query.dtype, device=query.device)

            output = HPUPagedAttention.forward_decode(query=query,
                                                      block_mapping=block_mapping,
                                                      block_bias=attn_bias,
                                                      block_groups=block_groups,
                                                      position_bias=self.position_bias,
                                                      **self.common_attention_args(block_list, key_cache, value_cache,
                                                                                   attn_metadata.block_size, k_scales,
                                                                                   v_scales))

        return output.view(*output_shape)

    def common_attention_args(self,
                              block_list=None,
                              key_cache=None,
                              value_cache=None,
                              block_size=None,
                              k_scales=None,
                              v_scales=None):
        """Build the keyword arguments shared by the prompt and decode attention calls.

        The per-call values (block list, caches, block size, scales) are passed
        through unchanged; everything else comes from this instance.
        """
        return {
            'scale': self.scale,
            'matmul_qk_op': self.matmul_qk,
            'matmul_av_op': self.matmul_av,
            'batch2block_matmul_op': self.batch2block_matmul,
            'block2batch_matmul_op': self.block2batch_matmul,
            'fsdpa_op': self.fused_scaled_dot_product_attention,
            'keys_fetch_func': self.k_cache.fetch_from_cache,
            'values_fetch_func': self.v_cache.fetch_from_cache,
            'softmax_op': self.softmax,
            'block_list': block_list,
            'key_cache': key_cache,
            'value_cache': value_cache,
            'block_size': block_size,
            "sinks": self.sinks,
            'k_scales': k_scales,
            'v_scales': v_scales,
        }

    def forward_encoder_decoder(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass for encoder-decoder (cross) attention with PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
                NOTE(review): the prompt branch unpacks ``key.shape`` into
                three values, so a 3-D key appears to be expected there -
                confirm against the caller.
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            k_scale: Accepted but not used by this method.
            v_scale: Accepted but not used by this method.
        Returns:
            Tensor viewed as [batch_size, -1, hidden_size].
        """
        batch_size, hidden_size = query.shape

        if attn_metadata.is_prompt:
            batch_size = attn_metadata.num_prefills
            batched_tokens, _ = query.shape
            batched_kv_tokens, _, _ = key.shape
            assert batch_size > 0, ("In prefill stage the num_prefills should be > 0")
            assert batched_tokens % batch_size == 0
            assert batched_kv_tokens % batch_size == 0
            seq_len = batched_tokens // batch_size

        query = query.unsqueeze(1)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        cross_slot_mapping = attn_metadata.cross_slot_mapping.flatten(
        ) if attn_metadata.cross_slot_mapping is not None else None
        # Defined up front (as in forward()) so the decode branch never hits
        # an unbound local when kv_cache is missing or not a tuple.
        key_cache = None
        value_cache = None
        k_scales = None
        v_scales = None
        if kv_cache is not None and isinstance(kv_cache, tuple):
            key_cache, value_cache, k_scales, v_scales = \
                HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            key_cache = self.k_cache(key,
                                     key_cache,
                                     cross_slot_mapping,
                                     scales=k_scales,
                                     block_size=attn_metadata.block_size,
                                     is_prompt=attn_metadata.is_prompt)
            value_cache = self.v_cache(value,
                                       value_cache,
                                       cross_slot_mapping,
                                       scales=v_scales,
                                       block_size=attn_metadata.block_size,
                                       is_prompt=attn_metadata.is_prompt)

        if attn_metadata.is_prompt:
            # Prompt run. batch_size already equals num_prefills (set above).
            query_shape = (batch_size, -1, self.num_heads, self.head_size)
            kv_shape = (batch_size, -1, self.num_kv_heads, self.head_size)
            out = ops.prompt_attention(impl=self.prefill_impl,
                                       query=query.view(query_shape),
                                       key=key.view(kv_shape),
                                       value=value.view(kv_shape),
                                       attn_bias=None,
                                       is_causal=False,
                                       **self.common_attention_args())
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Enc/dec cross-attention KVs match encoder sequence length;
            # cross-attention utilizes special "cross" block tables
            block_list = attn_metadata.cross_block_list
            block_mapping = attn_metadata.cross_block_mapping
            block_groups = attn_metadata.cross_block_groups
            attn_bias = attn_metadata.cross_attn_bias
            # Decoding run.
            output = HPUPagedAttention.forward_decode(query=query,
                                                      block_mapping=block_mapping,
                                                      block_bias=attn_bias,
                                                      block_groups=block_groups,
                                                      position_bias=None,
                                                      **self.common_attention_args(block_list, key_cache, value_cache,
                                                                                   attn_metadata.block_size, k_scales,
                                                                                   v_scales))
        # Reshape the output tensor.
        return output.view(batch_size, -1, hidden_size)

alibi_slopes instance-attribute

alibi_slopes = None

attn_type instance-attribute

attn_type = attn_type

batch2block_matmul instance-attribute

batch2block_matmul = (
    B2BMatmul() if not enable_fp8_attn else FP8Matmul()
)

block2batch_matmul instance-attribute

block2batch_matmul = (
    B2BMatmul() if not enable_fp8_attn else FP8Matmul()
)

enable_fp8_attn instance-attribute

enable_fp8_attn = (
    kv_cache_dtype == "fp8_inc"
    and get("QUANT_CONFIG", None) is None
)

fused_scaled_dot_product_attention instance-attribute

fused_scaled_dot_product_attention = (
    None
    if HPUFusedSDPA is None
    else ModuleFusedSDPA(HPUFusedSDPA)
)

head_size instance-attribute

head_size = head_size

is_chunked_attention instance-attribute

is_chunked_attention = False

k_cache instance-attribute

k_cache = (
    VLLMKVCache()
    if not enable_fp8_attn
    else VLLMFP8KVCache()
)

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

matmul_av instance-attribute

matmul_av = Matmul() if not enable_fp8_attn else FP8Matmul()

matmul_qk instance-attribute

matmul_qk = Matmul() if not enable_fp8_attn else FP8Matmul()

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = (
    num_heads if num_kv_heads is None else num_kv_heads
)

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

prefill_impl instance-attribute

prefill_impl = prompt_attn_impl

prev_attn instance-attribute

prev_attn = None

prompt_position_bias instance-attribute

prompt_position_bias = None

scale instance-attribute

scale = float(scale)

sinks instance-attribute

sinks = sinks

sliding_window instance-attribute

sliding_window = sliding_window

softmax instance-attribute

softmax = Softmax()

use_contiguous_pa instance-attribute

use_contiguous_pa = use_contiguous_pa

use_merged_prefill instance-attribute

use_merged_prefill = merged_prefill

v_cache instance-attribute

v_cache = (
    VLLMKVCache(is_v_cache=True)
    if not enable_fp8_attn
    else VLLMFP8KVCache()
)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
    sinks: Optional[Tensor] = None,
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
    sinks: Optional[torch.Tensor] = None,
) -> None:
    """Set up the operator wrappers and configuration for one attention layer.

    Args:
        num_heads: Number of query heads.
        head_size: Per-head dimension; must be one of
            ``HPUPagedAttention.get_supported_head_sizes()``.
        scale: Attention scale, stored as a Python float.
        num_kv_heads: Number of KV heads; ``None`` falls back to ``num_heads``.
        alibi_slopes: Optional ALiBi slopes; converted to a tensor when given.
        sliding_window: Sliding-window size, stored unchanged.
        kv_cache_dtype: KV-cache dtype name; ``'fp8_inc'`` (with no
            ``QUANT_CONFIG`` environment variable) selects the FP8 wrappers.
        logits_soft_cap: Accepted for interface compatibility; not read here.
        attn_type: One of DECODER, ENCODER_DECODER or ENCODER_ONLY.
        kv_sharing_target_layer_name: Stored; when set, an info line is logged.
        use_irope: Only triggers a one-time warning here.
        sinks: Optional tensor whose first dimension must equal ``num_heads``.

    Raises:
        ValueError: If ``head_size`` is not supported.
        NotImplementedError: If ``attn_type`` is not one of the accepted types.
    """
    # Two-argument super(): the MRO lookup starts *after* AttentionImpl.
    super(AttentionImpl, self).__init__()

    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
    if kv_sharing_target_layer_name is not None:
        logger.info(
            "[KV sharing] HPUAttentionImpl initialized with kv_sharing_target_layer_name: %s",
            self.kv_sharing_target_layer_name,
        )
    if use_irope:
        logger.warning_once(
            "Using irope in HPU is not supported yet, it will fall back to global attention for long context.")

    # FP8 attention is used only for the 'fp8_inc' cache dtype and only when
    # QUANT_CONFIG is absent from the environment.
    quant_config_unset = os.environ.get('QUANT_CONFIG') is None
    self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and quant_config_unset
    use_fp8 = self.enable_fp8_attn

    self.kv_cache_dtype = kv_cache_dtype
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)

    # Operator wrappers; the assignment order below is kept on purpose.
    self.matmul_qk = FP8Matmul() if use_fp8 else Matmul()
    self.softmax = Softmax()
    self.matmul_av = FP8Matmul() if use_fp8 else Matmul()
    self.batch2block_matmul = FP8Matmul() if use_fp8 else B2BMatmul()
    self.block2batch_matmul = FP8Matmul() if use_fp8 else B2BMatmul()
    self.k_cache = VLLMFP8KVCache() if use_fp8 else VLLMKVCache()
    self.v_cache = VLLMFP8KVCache() if use_fp8 else VLLMKVCache(is_v_cache=True)

    # The fused SDPA wrapper exists only when the kernel lookup returns one.
    fsdpa_kernel = kernels.fsdpa()
    if fsdpa_kernel is None:
        self.fused_scaled_dot_product_attention = None
    else:
        self.fused_scaled_dot_product_attention = ModuleFusedSDPA(fsdpa_kernel)

    self.prefill_impl = get_config().prompt_attn_impl
    self.use_contiguous_pa = get_config().use_contiguous_pa
    self.use_merged_prefill = get_config().merged_prefill

    uses_alibi = alibi_slopes is not None
    if uses_alibi:
        # ALiBi is rejected for the flex / fused-SDPA prefill paths and for
        # non-contiguous paged attention.
        assert self.prefill_impl != 'flex_impl', \
            'Prefill with Flex Attention not supported with alibi slopes!'
        assert self.prefill_impl != 'fsdpa_impl', \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        assert self.use_contiguous_pa, \
            'Non-contiguous PA not supported with alibi slopes!'

    if num_kv_heads is None:
        self.num_kv_heads = num_heads
    else:
        self.num_kv_heads = num_kv_heads
    self.sliding_window = sliding_window
    self.prompt_position_bias = None
    self.prev_attn = None
    self.alibi_slopes = None
    if uses_alibi:
        if get_config().fp32_alibi_biases:
            slopes_dtype = torch.float32
        else:
            slopes_dtype = torch.bfloat16
        self.alibi_slopes = torch.tensor(alibi_slopes, dtype=slopes_dtype)

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        raise ValueError(f"Head size {head_size} is not supported by PagedAttention. "
                         f"Supported head sizes are: {supported_head_sizes}.")

    self.attn_type = attn_type
    accepted_attn_types = (
        AttentionType.DECODER,
        AttentionType.ENCODER_DECODER,
        AttentionType.ENCODER_ONLY,
    )
    if self.attn_type not in accepted_attn_types:
        raise NotImplementedError("Encoder self-attention is not implemented for HPUAttentionImpl")

    self.sinks = sinks
    if sinks is not None:
        assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
                                             f"heads in the layer. Sinks shape: {sinks.shape}, "
                                             f"num_heads: {num_heads}.")

    self.is_chunked_attention = False

_maybe_init_alibi_biases

_maybe_init_alibi_biases(
    max_seq_len, prev_attn: Optional[Module] = None
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _maybe_init_alibi_biases(
    self,
    max_seq_len,
    prev_attn: Optional[torch.nn.Module] = None,
) -> None:
    """Record ``max_seq_len`` and prepare the prompt ALiBi bias if slopes exist.

    Args:
        max_seq_len: Stored on the instance and used as the length of the
            precomputed prompt bias.
        prev_attn: Optional layer whose ``.impl`` is stored as ``self.prev_attn``.
            When present (and this layer has slopes), its ``alibi_slopes`` and
            ``prompt_position_bias`` are reused instead of building new ones.
    """
    self.max_seq_len = max_seq_len
    self.prev_attn = prev_attn.impl if prev_attn is not None else None

    # Nothing to prepare for layers without ALiBi slopes.
    if self.alibi_slopes is None:
        return

    # Share the tensors already held by the previous layer's implementation.
    if self.prev_attn is not None:
        self.alibi_slopes = self.prev_attn.alibi_slopes
        self.prompt_position_bias = self.prev_attn.prompt_position_bias
        return

    # No earlier layer to borrow from: build the prompt bias once so it can be
    # reused whenever the sequence length permits.
    self.prompt_position_bias = _make_prompt_alibi_bias(
        alibi_slopes=self.alibi_slopes,
        seq_len=self.max_seq_len,
        dtype=self.alibi_slopes.dtype,
    )

common_attention_args

common_attention_args(
    block_list=None,
    key_cache=None,
    value_cache=None,
    block_size=None,
    k_scales=None,
    v_scales=None,
)
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def common_attention_args(self,
                          block_list=None,
                          key_cache=None,
                          value_cache=None,
                          block_size=None,
                          k_scales=None,
                          v_scales=None):
    """Assemble the keyword arguments shared by the attention kernel calls.

    The result combines this layer's operator objects (matmuls, softmax,
    fused SDPA, cache fetch functions), its ``scale`` and ``sinks``, and the
    per-call values passed in. Every per-call argument defaults to ``None``
    and is forwarded unchanged.

    Returns:
        A new ``dict`` suitable for ``**`` expansion into an attention op.
    """
    return dict(
        # Layer-owned scale and operator objects.
        scale=self.scale,
        matmul_qk_op=self.matmul_qk,
        matmul_av_op=self.matmul_av,
        batch2block_matmul_op=self.batch2block_matmul,
        block2batch_matmul_op=self.block2batch_matmul,
        fsdpa_op=self.fused_scaled_dot_product_attention,
        keys_fetch_func=self.k_cache.fetch_from_cache,
        values_fetch_func=self.v_cache.fetch_from_cache,
        softmax_op=self.softmax,
        # Per-call values supplied by the caller.
        block_list=block_list,
        key_cache=key_cache,
        value_cache=value_cache,
        block_size=block_size,
        sinks=self.sinks,
        k_scales=k_scales,
        v_scales=v_scales,
    )

forward

forward(
    layer: AttentionLayer,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
    output: Optional[Tensor] = None,
) -> Tensor

Forward pass with PagedAttention.

Parameters:

Name Type Description Default
query Tensor

shape = [num_tokens, num_heads * head_size]

required
key Tensor

shape = [num_tokens, num_kv_heads * head_size]

required
value Tensor

shape = [num_tokens, num_kv_heads * head_size]

required
attn_metadata HPUAttentionMetadata

Metadata for attention.

required

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm_gaudi/attention/backends/hpu_attn.py
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: HPUAttentionMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with PagedAttention.

    Dispatches to one of three paths:

    * ``attn_type == ENCODER_DECODER``: delegates to
      ``forward_encoder_decoder``.
    * prompt (``attn_metadata.is_prompt`` or more than one token per
      sequence): ``ops.prompt_attention``.
    * decode (otherwise): ``HPUPagedAttention.forward_decode``; returns zeros
      of the query's shape when no key cache is available.

    Args:
        layer: Attention layer; only its ``_k_scale_float`` /
            ``_v_scale_float`` are read, and only for the encoder-decoder path.
        query: shape = [num_tokens, num_heads * head_size]
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache: The cache is only used when this is a ``tuple`` (it is then
            split with ``HPUPagedAttention.split_kv_cache``); any other value,
            including ``None``, skips the cache write.
        attn_metadata: Metadata for attention.
        output: Accepted for interface compatibility; the incoming value is
            not read, the name is rebound to the computed result.
    Returns:
        Tensor viewed back to the shape ``query`` had on entry.
    """
    if self.attn_type == AttentionType.ENCODER_DECODER:
        return self.forward_encoder_decoder(
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            k_scale=layer._k_scale_float,
            # Was `layer._k_scale_float`: the value scale must come from the
            # layer's value-scale attribute, not the key one.
            v_scale=layer._v_scale_float,
        )
    # Set return shape
    output_shape = query.shape
    if query.dim() == 2:
        # Flat [num_tokens, hidden] input: recover batch size from metadata,
        # then fold tokens into [batch, seq_len, hidden].
        if attn_metadata.seq_lens_tensor is not None:
            batch_size = attn_metadata.seq_lens_tensor.shape[0] if not self.use_merged_prefill else 1
        else:
            assert attn_metadata.block_mapping is not None, \
                "seq_lens_tensor must be provided for attention"
            batch_size = attn_metadata.block_mapping.shape[1]
        num_tokens, hidden_size = query.shape
        seq_len = num_tokens // batch_size
        query = query.view(batch_size, seq_len, -1)
    else:
        batch_size, seq_len, hidden_size = query.shape

    key = key.view(-1, self.num_kv_heads, self.head_size)
    value = value.view(-1, self.num_kv_heads, self.head_size)
    slot_mapping = attn_metadata.slot_mapping.flatten() if attn_metadata.slot_mapping is not None else None
    key_cache = None
    value_cache = None
    k_scales = None
    v_scales = None
    if kv_cache is not None and isinstance(kv_cache, tuple):
        key_cache, value_cache, k_scales, v_scales = \
            HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)
        # Only float32 inputs are cast down to the cache dtype; the query is
        # then aligned with the (possibly cast) key dtype.
        if key.dtype == torch.float32 and key.dtype != key_cache.dtype:
            key = key.to(key_cache.dtype)
        if value.dtype == torch.float32 and value.dtype != value_cache.dtype:
            value = value.to(value_cache.dtype)
        if query.dtype != key.dtype:
            query = query.to(key.dtype)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            # Layers with a KV-sharing target skip the write entirely.
            key_cache = self.k_cache(key,
                                     key_cache,
                                     slot_mapping,
                                     scales=k_scales,
                                     block_size=attn_metadata.block_size,
                                     is_prompt=attn_metadata.is_prompt)
            value_cache = self.v_cache(value,
                                       value_cache,
                                       slot_mapping,
                                       scales=v_scales,
                                       block_size=attn_metadata.block_size,
                                       is_prompt=attn_metadata.is_prompt)

    if attn_metadata.is_prompt or seq_len > 1:
        # Prompt run.
        query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
        kv_shape = (batch_size, -1, self.num_kv_heads, self.head_size)

        attn_bias = attn_metadata.attn_bias
        position_bias = None
        # With ALiBi slopes and no block list, select the positional bias
        # that matches the last two dimensions of attn_bias.
        if (attn_metadata.block_list is None and self.prompt_position_bias is not None
                and self.alibi_slopes is not None):
            assert attn_bias is not None, \
                    'attn_bias must be set before calling ' \
                    'model.forward with alibi biases'
            slice_1_size = attn_bias.size(-2)
            slice_2_size = attn_bias.size(-1)
            if self.max_seq_len >= max(slice_1_size, slice_2_size):
                # Using pre-computed prompt_position_bias subset.
                position_bias = self.prompt_position_bias[:, :, -slice_1_size:, -slice_2_size:]

            else:
                # For longer sequences than precomputed,
                # recreate the bias. This is memory inefficient.
                position_bias = _make_prompt_alibi_bias(
                    alibi_slopes=self.alibi_slopes,
                    seq_len=max(slice_1_size, slice_2_size),
                    dtype=self.alibi_slopes.dtype,
                )

        block_list = attn_metadata.block_list if attn_metadata \
            and attn_metadata.block_list is not None else None

        common_args = self.common_attention_args(block_list, key_cache, value_cache, attn_metadata.block_size,
                                                 k_scales, v_scales)

        if self.sliding_window:
            # Prefer a precomputed window bias; otherwise drop the bias and
            # pass the window size to the kernel instead.
            if hasattr(attn_metadata, 'window_attn_bias') and attn_metadata.window_attn_bias is not None:
                attn_bias = attn_metadata.window_attn_bias
            else:
                attn_bias = None
                window_size = (self.sliding_window, 0)
                common_args['window_size'] = window_size
        # A chunked-attention bias, when present, overrides the choice above.
        if self.is_chunked_attention and \
            hasattr(attn_metadata, 'chunked_attn_bias') and attn_metadata.chunked_attn_bias is not None:
            attn_bias = attn_metadata.chunked_attn_bias

        out = ops.prompt_attention(impl=self.prefill_impl,
                                   query=query.view(query_shape),
                                   key=key.view(kv_shape),
                                   value=value.view(kv_shape),
                                   is_causal=True,
                                   attn_bias=attn_bias,
                                   position_bias=position_bias,
                                   valid_seq_lengths=attn_metadata.seq_lens_tensor,
                                   **common_args)

        output = out.reshape(batch_size, seq_len, hidden_size)
    else:
        # Decoding run.
        # Pick the block tables: sliding-window first, then chunked, then the
        # default ones.
        if self.sliding_window and \
            attn_metadata.window_block_list is not None:
            block_list = attn_metadata.window_block_list
            block_groups = attn_metadata.window_block_groups
            block_mapping = attn_metadata.window_block_mapping
            attn_bias = attn_metadata.window_attn_bias
        elif self.is_chunked_attention and \
            attn_metadata.chunked_block_list is not None:
            block_list = attn_metadata.chunked_block_list
            block_groups = attn_metadata.chunked_block_groups
            block_mapping = attn_metadata.chunked_block_mapping
            attn_bias = attn_metadata.chunked_attn_bias
        else:
            block_list = attn_metadata.block_list
            block_groups = attn_metadata.block_groups
            block_mapping = attn_metadata.block_mapping
            attn_bias = attn_metadata.attn_bias

        # Stored on self so a later layer can reuse it through prev_attn.
        self.position_bias = None
        alibi_blocks = getattr(attn_metadata, 'alibi_blocks', None)
        if self.alibi_slopes is not None and alibi_blocks is not None:
            if self.prev_attn is not None:
                self.position_bias = self.prev_attn.position_bias
            else:
                # For decoding, compute position bias using alibi_blocks.
                self.position_bias = _make_decode_alibi_bias(
                    alibi_blocks=alibi_blocks,
                    alibi_slopes=self.alibi_slopes,
                    dtype=self.alibi_slopes.dtype,
                )

        # No cache to attend over: return zeros shaped like the input query.
        if key_cache is None:
            return torch.zeros(*output_shape, dtype=query.dtype, device=query.device)

        output = HPUPagedAttention.forward_decode(query=query,
                                                  block_mapping=block_mapping,
                                                  block_bias=attn_bias,
                                                  block_groups=block_groups,
                                                  position_bias=self.position_bias,
                                                  **self.common_attention_args(block_list, key_cache, value_cache,
                                                                               attn_metadata.block_size, k_scales,
                                                                               v_scales))

    return output.view(*output_shape)

forward_encoder_decoder

forward_encoder_decoder(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
) -> Tensor

Forward pass for encoder-decoder cross-attention with PagedAttention.

Parameters:

Name Type Description Default
query Tensor

shape = [num_tokens, num_heads * head_size]

required
key Tensor

shape = [num_tokens, num_kv_heads * head_size]

required
value Tensor

shape = [num_tokens, num_kv_heads * head_size]

required
attn_metadata HPUAttentionMetadata

Metadata for attention.

required

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm_gaudi/attention/backends/hpu_attn.py
def forward_encoder_decoder(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: HPUAttentionMetadata,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
) -> torch.Tensor:
    """Forward pass for encoder-decoder cross-attention with PagedAttention.

    Keys/values are written to the cache through the ``cross_slot_mapping``
    in ``attn_metadata``. Prompt runs call ``ops.prompt_attention`` (non-causal,
    no bias); decode runs call ``HPUPagedAttention.forward_decode`` with the
    ``cross_*`` block tables.

    Args:
        query: shape = [num_tokens, num_heads * head_size]
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
        attn_metadata: Metadata for attention.
        k_scale: Not referenced in this function body.
        v_scale: Not referenced in this function body.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    # query must be 2-D here; batch_size is provisional and is replaced by
    # num_prefills on the prompt path.
    batch_size, hidden_size = query.shape

    if attn_metadata.is_prompt:
        batch_size = attn_metadata.num_prefills
        batched_tokens, _ = query.shape
        # NOTE(review): this unpack requires a non-None, 3-D key, which differs
        # from the 2-D shape in the docstring and from the `key is not None`
        # check below -- confirm against the caller.
        batched_kv_tokens, _, _ = key.shape
        assert batch_size > 0, ("In prefill stage the num_prefills should be > 0")
        assert batched_tokens % batch_size == 0
        assert batched_kv_tokens % batch_size == 0
        # seq_len is only bound on the prompt path; it is only read there too.
        seq_len = batched_tokens // batch_size

    query = query.unsqueeze(1)
    if key is not None:
        assert value is not None
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
    else:
        assert value is None

    cross_slot_mapping = attn_metadata.cross_slot_mapping.flatten(
    ) if attn_metadata.cross_slot_mapping is not None else None
    # NOTE(review): key_cache / value_cache / k_scales / v_scales are bound only
    # inside this branch, yet the decode path below reads them unconditionally.
    if kv_cache is not None and isinstance(kv_cache, tuple):
        key_cache, value_cache, k_scales, v_scales = \
            HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

        # Reshape the input keys and values and store them in the cache.
        # If kv_cache is not provided, the new key and value tensors are
        # not cached. This happens during the initial memory profiling run.
        key_cache = self.k_cache(key,
                                 key_cache,
                                 cross_slot_mapping,
                                 scales=k_scales,
                                 block_size=attn_metadata.block_size,
                                 is_prompt=attn_metadata.is_prompt)
        value_cache = self.v_cache(value,
                                   value_cache,
                                   cross_slot_mapping,
                                   scales=v_scales,
                                   block_size=attn_metadata.block_size,
                                   is_prompt=attn_metadata.is_prompt)

    if attn_metadata.is_prompt:
        # Prompt run.
        batch_size = attn_metadata.num_prefills

        query_shape = (batch_size, -1, self.num_heads, self.head_size)
        kv_shape = (batch_size, -1, self.num_kv_heads, self.head_size)
        # common_attention_args() is called with no arguments, so block list,
        # caches, block size and scales are all passed as None here.
        out = ops.prompt_attention(impl=self.prefill_impl,
                                   query=query.view(query_shape),
                                   key=key.view(kv_shape),
                                   value=value.view(kv_shape),
                                   attn_bias=None,
                                   is_causal=False,
                                   **self.common_attention_args())
        output = out.reshape(batch_size, seq_len, hidden_size)
    else:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        block_list = attn_metadata.cross_block_list
        block_mapping = attn_metadata.cross_block_mapping
        block_groups = attn_metadata.cross_block_groups
        attn_bias = attn_metadata.cross_attn_bias
        # Decoding run.
        output = HPUPagedAttention.forward_decode(query=query,
                                                  block_mapping=block_mapping,
                                                  block_bias=attn_bias,
                                                  block_groups=block_groups,
                                                  position_bias=None,
                                                  **self.common_attention_args(block_list, key_cache, value_cache,
                                                                               attn_metadata.block_size, k_scales,
                                                                               v_scales))
    # Reshape the output tensor.
    return output.view(batch_size, -1, hidden_size)

HPUAttentionMetadata dataclass

Bases: HPUPagedAttentionMetadata, AttentionMetadata

Metadata for HPUAttentionBackend.

Source code in vllm_gaudi/attention/backends/hpu_attn.py
@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionBackend.

    Per-call state consumed by the attention implementations in this module.
    Fields without a default must be supplied by the caller; the optional
    ``cross_*``, ``window_*`` and ``chunked_*`` groups hold alternative block
    tables / biases for cross-attention, sliding-window and chunked attention.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # Forwarded to the KV-cache writes and to the attention ops.
    block_size: int
    # Not read by the attention code visible in this file section.
    prep_initial_states: bool
    # Flattened and used as the destination slots for KV-cache writes.
    slot_mapping: torch.Tensor
    # Default attention bias (prompt) / block bias (decode).
    attn_bias: Optional[torch.Tensor]
    # When set, its first dimension gives the batch size; also passed as
    # valid_seq_lengths on the prompt path.
    seq_lens_tensor: Optional[torch.Tensor]
    context_lens_tensor: Optional[torch.Tensor]
    input_positions: torch.Tensor
    seq_lens: Optional[list[int]] = None
    # Encoder-side lengths (presumably for encoder-decoder models -- TODO confirm).
    encoder_seq_lens: Optional[list[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    max_encoder_seq_len: Optional[int] = None
    # Cross-attention tables, read by forward_encoder_decoder.
    cross_block_list: Optional[torch.Tensor] = None
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_mapping: Optional[torch.Tensor] = None
    cross_block_groups: Optional[torch.Tensor] = None
    cross_block_usage: Optional[torch.Tensor] = None
    cross_attn_bias: Optional[torch.Tensor] = None
    # Sliding-window tables: used on decode when the layer has a sliding
    # window and window_block_list is set; window_attn_bias is also used on
    # the prompt path.
    window_block_list: Optional[torch.Tensor] = None
    window_slot_mapping: Optional[torch.Tensor] = None
    window_block_mapping: Optional[torch.Tensor] = None
    window_block_groups: Optional[torch.Tensor] = None
    window_block_usage: Optional[torch.Tensor] = None
    window_attn_bias: Optional[torch.Tensor] = None
    # Chunked-attention tables: used when the layer has is_chunked_attention
    # set and the corresponding field is not None.
    chunked_slot_mapping: Optional[torch.Tensor] = None
    chunked_attn_bias: Optional[torch.Tensor] = None
    chunked_block_mapping: Optional[torch.Tensor] = None
    chunked_block_list: Optional[torch.Tensor] = None
    chunked_block_groups: Optional[torch.Tensor] = None
    chunked_block_usage: Optional[torch.Tensor] = None
    # The remaining fields are not read by the attention code visible in this
    # file section.
    has_initial_states_p: Optional[torch.Tensor] = None
    last_chunk_indices_p: Optional[torch.Tensor] = None
    load_indices_tensor: Optional[torch.Tensor] = None  # shape: [batch,]
    store_indices_tensor: Optional[torch.Tensor] = None  # shape: [batch,]

attn_bias instance-attribute

attn_bias: Optional[Tensor]

block_size instance-attribute

block_size: int

chunked_attn_bias class-attribute instance-attribute

chunked_attn_bias: Optional[Tensor] = None

chunked_block_groups class-attribute instance-attribute

chunked_block_groups: Optional[Tensor] = None

chunked_block_list class-attribute instance-attribute

chunked_block_list: Optional[Tensor] = None

chunked_block_mapping class-attribute instance-attribute

chunked_block_mapping: Optional[Tensor] = None

chunked_block_usage class-attribute instance-attribute

chunked_block_usage: Optional[Tensor] = None

chunked_slot_mapping class-attribute instance-attribute

chunked_slot_mapping: Optional[Tensor] = None

context_lens_tensor instance-attribute

context_lens_tensor: Optional[Tensor]

cross_attn_bias class-attribute instance-attribute

cross_attn_bias: Optional[Tensor] = None

cross_block_groups class-attribute instance-attribute

cross_block_groups: Optional[Tensor] = None

cross_block_list class-attribute instance-attribute

cross_block_list: Optional[Tensor] = None

cross_block_mapping class-attribute instance-attribute

cross_block_mapping: Optional[Tensor] = None

cross_block_usage class-attribute instance-attribute

cross_block_usage: Optional[Tensor] = None

cross_slot_mapping class-attribute instance-attribute

cross_slot_mapping: Optional[Tensor] = None

encoder_seq_lens class-attribute instance-attribute

encoder_seq_lens: Optional[list[int]] = None

encoder_seq_lens_tensor class-attribute instance-attribute

encoder_seq_lens_tensor: Optional[Tensor] = None

has_initial_states_p class-attribute instance-attribute

has_initial_states_p: Optional[Tensor] = None

input_positions instance-attribute

input_positions: Tensor

is_prompt instance-attribute

is_prompt: bool

last_chunk_indices_p class-attribute instance-attribute

last_chunk_indices_p: Optional[Tensor] = None

load_indices_tensor class-attribute instance-attribute

load_indices_tensor: Optional[Tensor] = None

max_encoder_seq_len class-attribute instance-attribute

max_encoder_seq_len: Optional[int] = None

prep_initial_states instance-attribute

prep_initial_states: bool

seq_lens class-attribute instance-attribute

seq_lens: Optional[list[int]] = None

seq_lens_tensor instance-attribute

seq_lens_tensor: Optional[Tensor]

slot_mapping instance-attribute

slot_mapping: Tensor

store_indices_tensor class-attribute instance-attribute

store_indices_tensor: Optional[Tensor] = None

window_attn_bias class-attribute instance-attribute

window_attn_bias: Optional[Tensor] = None

window_block_groups class-attribute instance-attribute

window_block_groups: Optional[Tensor] = None

window_block_list class-attribute instance-attribute

window_block_list: Optional[Tensor] = None

window_block_mapping class-attribute instance-attribute

window_block_mapping: Optional[Tensor] = None

window_block_usage class-attribute instance-attribute

window_block_usage: Optional[Tensor] = None

window_slot_mapping class-attribute instance-attribute

window_slot_mapping: Optional[Tensor] = None

__init__

__init__(
    block_list: Optional[Tensor],
    block_mapping: Optional[Tensor],
    block_usage: Optional[Tensor],
    block_groups: Optional[Tensor],
    alibi_blocks: Optional[Tensor],
    is_prompt: bool,
    block_size: int,
    prep_initial_states: bool,
    slot_mapping: Tensor,
    attn_bias: Optional[Tensor],
    seq_lens_tensor: Optional[Tensor],
    context_lens_tensor: Optional[Tensor],
    input_positions: Tensor,
    seq_lens: Optional[list[int]] = None,
    encoder_seq_lens: Optional[list[int]] = None,
    encoder_seq_lens_tensor: Optional[Tensor] = None,
    max_encoder_seq_len: Optional[int] = None,
    cross_block_list: Optional[Tensor] = None,
    cross_slot_mapping: Optional[Tensor] = None,
    cross_block_mapping: Optional[Tensor] = None,
    cross_block_groups: Optional[Tensor] = None,
    cross_block_usage: Optional[Tensor] = None,
    cross_attn_bias: Optional[Tensor] = None,
    window_block_list: Optional[Tensor] = None,
    window_slot_mapping: Optional[Tensor] = None,
    window_block_mapping: Optional[Tensor] = None,
    window_block_groups: Optional[Tensor] = None,
    window_block_usage: Optional[Tensor] = None,
    window_attn_bias: Optional[Tensor] = None,
    chunked_slot_mapping: Optional[Tensor] = None,
    chunked_attn_bias: Optional[Tensor] = None,
    chunked_block_mapping: Optional[Tensor] = None,
    chunked_block_list: Optional[Tensor] = None,
    chunked_block_groups: Optional[Tensor] = None,
    chunked_block_usage: Optional[Tensor] = None,
    has_initial_states_p: Optional[Tensor] = None,
    last_chunk_indices_p: Optional[Tensor] = None,
    load_indices_tensor: Optional[Tensor] = None,
    store_indices_tensor: Optional[Tensor] = None,
) -> None

HPUMLAAttentionBackend

Bases: HPUAttentionBackend

Source code in vllm_gaudi/attention/backends/hpu_attn.py
@register_backend(AttentionBackendEnum.CUSTOM, "HPU_MLA")
class HPUMLAAttentionBackend(HPUAttentionBackend):
    """MLA flavour of the HPU attention backend.

    Supplies the backend name, implementation class, metadata class and
    KV-cache shape; everything else is inherited from ``HPUAttentionBackend``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the backend name string."""
        backend_name = "CUSTOM"
        return backend_name

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the attention implementation class for this backend."""
        impl_cls = HPUMLAImpl
        return impl_cls

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the metadata class for this backend."""
        metadata_cls = HPUMLAMetadata
        return metadata_cls

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the 2-D cache shape ``(num_blocks * block_size, head_size)``.

        ``num_kv_heads`` is accepted for interface compatibility and does not
        influence the result.
        """
        total_slots = num_blocks * block_size
        return total_slots, head_size

get_impl_cls staticmethod

get_impl_cls() -> type[AttentionImpl]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_impl_cls() -> type["AttentionImpl"]:
    """Return the attention implementation class for this backend."""
    impl_cls = HPUMLAImpl
    return impl_cls

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    return (num_blocks * block_size, head_size)

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    """Return the metadata class for this backend."""
    metadata_cls = HPUMLAMetadata
    return metadata_cls

get_name staticmethod

get_name() -> str
Source code in vllm_gaudi/attention/backends/hpu_attn.py
@staticmethod
def get_name() -> str:
    return "CUSTOM"

HPUMLAImpl

Bases: MLACommonImpl[HPUAttentionMetadata], Module

Source code in vllm_gaudi/attention/backends/hpu_attn.py
class HPUMLAImpl(MLACommonImpl[HPUAttentionMetadata], torch.nn.Module):
    """MLA attention implementation for HPU.

    ``forward_mha`` serves the prefill path through ``ops.prompt_attention``;
    ``forward_mqa`` serves the decode path through
    ``HPUPagedAttention.forward_decode``. Only the latent key cache is used;
    no separate value cache is passed to either kernel.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        sinks: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> None:
        """Store the MLA dimensions and build the matmul/softmax/cache helper modules.

        Args:
            num_heads: Number of attention heads.
            head_size: Head size forwarded to ``HPUPagedAttention.split_kv_cache``.
            scale: Softmax scale; stored as a Python ``float``.
            num_kv_heads: Number of KV heads.
            alibi_slopes: Must be falsy; ALiBi is not supported here.
            sliding_window: Must be falsy; sliding window is not supported here.
            kv_cache_dtype: KV-cache dtype string; ``'fp8_inc'`` enables the
                FP8 modules when the ``QUANT_CONFIG`` env var is unset.
            logits_soft_cap: Must be falsy; soft-capping is not supported here.
            attn_type: Must equal ``AttentionType.DECODER``.
            kv_sharing_target_layer_name: Accepted but not used in this method.
            q_lora_rank: Stored on the instance.
            kv_lora_rank: Width of the latent part of the cached vector.
            qk_nope_head_dim: Per-head width of the non-rotary key part.
            qk_rope_head_dim: Width of the rotary part of the cached vector.
            qk_head_dim: Per-head width of query/key as seen by attention.
            v_head_dim: Per-head width of value.
            kv_b_proj: Up-projection applied to the latent in ``forward_mha``.
            sinks: Optional tensor whose first dimension must equal ``num_heads``.
            **kwargs: Accepted and ignored by this method.

        Raises:
            NotImplementedError: If any of ``alibi_slopes``, ``sliding_window``
                or ``logits_soft_cap`` is truthy, or ``attn_type`` is not
                ``AttentionType.DECODER``.
            AssertionError: If the prefill impl is ``'fsdpa_impl'`` while
                ``alibi_slopes`` is set, or ``sinks`` has the wrong leading size.
        """
        torch.nn.Module.__init__(self)

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.kv_b_proj = kv_b_proj

        # NOTE(kzawora): restore this once https://github.com/vllm-project/vllm/pull/25385 is merged
        #MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window,
        #                       kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **kwargs)

        # FP8 attention modules are used only for the 'fp8_inc' cache dtype and
        # only when the QUANT_CONFIG environment variable is not set.
        self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
        self.matmul_qk = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.batch2block_matmul = B2BMatmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.block2batch_matmul = B2BMatmul() if not self.enable_fp8_attn \
            else FP8Matmul()
        self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \
            else VLLMFP8KVCache()
        HPUFusedSDPA = kernels.fsdpa()
        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
            else ModuleFusedSDPA(HPUFusedSDPA)

        # Best effort: when the FP8 fused kernel is importable and FP8 attention
        # is enabled, it replaces the regular fused SDPA module; otherwise the
        # module chosen above is kept.
        try:
            from habana_frameworks.torch.hpex.kernels import fp8_fused_sdpa
            if self.enable_fp8_attn:
                self.fused_scaled_dot_product_attention = ModuleFP8FusedSDPA(fp8_fused_sdpa)
        except ImportError:
            pass

        self.use_merged_prefill = get_config().merged_prefill
        self.prefill_impl = get_config().prompt_attn_impl
        assert self.prefill_impl != 'fsdpa_impl' or alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        # NOTE(review): these two flags are not read anywhere in this class;
        # presumably inherited MLACommonImpl code expects them — confirm.
        self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
        # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
        self.is_aiter_triton_fp4_bmm_enabled = (rocm_aiter_ops.is_fp4bmm_enabled()
                                                and self.kv_b_proj.weight.dtype == torch.bfloat16)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError("HPUMLAImpl does not support one of the following: "
                                      "alibi_slopes, sliding_window, "
                                      "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            # The message names this class (it previously named TritonMLAImpl).
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "HPUMLAImpl")
        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
                                                 f"num_heads: {num_heads}.")

    def forward_mha(  # type: ignore
            self, q: torch.Tensor, latent_vec_k: torch.Tensor, k_cache: torch.Tensor,
            attn_metadata: HPUAttentionMetadata) -> torch.Tensor:
        """Run prefill attention over the (optionally prefix-extended) latent keys.

        Args:
            q: Query tensor; viewed as ``(batch, -1, num_heads, qk_head_dim)``.
            latent_vec_k: Latent vectors whose last dim splits into
                ``[kv_lora_rank, qk_rope_head_dim]``.
            k_cache: Latent cache, or a tuple whose first element is that cache.
                Only read when ``attn_metadata.block_list`` is not ``None``.
            attn_metadata: Supplies ``block_list``, ``block_size``,
                ``seq_lens_tensor`` and ``attn_bias``.

        Returns:
            Attention output reshaped to ``(-1, num_heads * v_head_dim)``.
        """

        ##### get prefix cache #####
        if attn_metadata.block_list is not None:
            current = latent_vec_k
            # Patch for vllm-gaudi kv_cache tuple format.
            if isinstance(k_cache, tuple):
                k_cache = k_cache[0]  # Use only key_cache for MLA
            past = self.latent_cache_k.fetch_from_cache(k_cache.unflatten(0, (-1, attn_metadata.block_size)),
                                                        attn_metadata.block_list)
            past = past.view(-1, past.shape[-1])
            # Cached entries are placed before the current latents.
            current = torch.concat((past, current), dim=0)
            latent_vec_k = current
        # =========================== #

        k_c_normed, k_pe = latent_vec_k.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

        # Up-project the latent and split the result per head into key and value parts.
        kv_nope = self.kv_b_proj(k_c_normed)[0]\
            .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The single rotary slice is broadcast across all heads before concatenation.
        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        if not self.use_merged_prefill:
            assert attn_metadata.seq_lens_tensor is not None, \
                "seq_lens_tensor must be provided for prefill attention"
            batch_size = attn_metadata.seq_lens_tensor.shape[0]
        else:
            batch_size = 1
        q = q.view(batch_size, -1, self.num_heads, self.qk_head_dim)
        k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim)
        v = v.view(batch_size, -1, self.num_heads, self.v_head_dim)

        # Zero-pad V's last dim up to the query head dim; the padding is sliced
        # off the output below.
        to_pad = self.qk_head_dim - self.v_head_dim
        if to_pad > 0:
            v_padding = torch.zeros(*v.shape[:-1], q.shape[-1] - v.shape[-1], device=v.device, dtype=v.dtype)
            v_padded = torch.cat((v, v_padding), dim=-1)
        else:
            v_padded = v

        output = ops.prompt_attention(impl=self.prefill_impl,
                                      query=q,
                                      key=k,
                                      value=v_padded,
                                      is_causal=True,
                                      attn_bias=attn_metadata.attn_bias,
                                      position_bias=None,
                                      valid_seq_lengths=attn_metadata.seq_lens_tensor,
                                      scale=self.scale,
                                      matmul_qk_op=self.matmul_qk,
                                      softmax_op=self.softmax,
                                      matmul_av_op=self.matmul_av,
                                      keys_fetch_func=self.latent_cache_k.fetch_from_cache,
                                      values_fetch_func=None,
                                      fsdpa_op=self.fused_scaled_dot_product_attention)
        # remove padding
        output = output.view(batch_size, -1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]

        return output.reshape(-1, self.num_heads * v.shape[-1])

    def forward_mqa(  # type: ignore
            self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_cache: torch.Tensor,
            attn_metadata: HPUAttentionMetadata) -> torch.Tensor:
        """Run decode attention against the paged latent cache.

        Args:
            q_nope: Non-rotary query part; concatenated with ``q_pe`` on the last dim.
            q_pe: Rotary query part.
            k_cache: Latent cache, a tuple whose first element is that cache,
                or ``None``.
            attn_metadata: Supplies ``block_list``, ``block_mapping``,
                ``attn_bias``, ``block_groups`` and ``block_size``.

        Returns:
            Whatever ``HPUPagedAttention.forward_decode`` returns.
        """
        # NOTE(review): none of the four names unpacked here is read afterwards:
        # key_cache/value_cache are reassigned below and k_scales/v_scales are
        # unused. Confirm whether this call is still required.
        if k_cache is not None and isinstance(k_cache, tuple):
            key_cache, value_cache, k_scales, v_scales = \
                HPUPagedAttention.split_kv_cache(k_cache, self.num_kv_heads, self.head_size)
        if isinstance(k_cache, tuple):
            k_cache = k_cache[0]  # Use only key_cache for MLA
        query = torch.cat([q_nope, q_pe], dim=-1)
        key_cache = k_cache.unsqueeze(1) if k_cache is not None else None
        value_cache = None
        output = HPUPagedAttention.forward_decode(query=query,
                                                  key_cache=key_cache,
                                                  value_cache=value_cache,
                                                  block_list=attn_metadata.block_list,
                                                  block_mapping=attn_metadata.block_mapping,
                                                  block_bias=attn_metadata.attn_bias,
                                                  block_groups=attn_metadata.block_groups,
                                                  block_size=attn_metadata.block_size,
                                                  scale=self.scale,
                                                  matmul_qk_op=self.matmul_qk,
                                                  matmul_av_op=self.matmul_av,
                                                  batch2block_matmul_op=self.batch2block_matmul,
                                                  block2batch_matmul_op=self.block2batch_matmul,
                                                  keys_fetch_func=self.latent_cache_k.fetch_from_cache,
                                                  values_fetch_func=None,
                                                  kv_lora_rank=self.kv_lora_rank)
        return output

    # NOTE(Xinyu): Make the loaded weight contiguous to avoid the transpose
    # during each graph execution
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Run the parent post-load step, then make ``W_UV``/``W_UK_T`` contiguous on HPU.

        Args:
            act_dtype: Passed through unchanged to the parent implementation.
        """
        super().process_weights_after_loading(act_dtype)
        # W_UV and W_UK_T are plain tensor attributes (not nn.Parameter or
        # register_buffer), so model.to('hpu') won't move them.  When INC
        # CPU-first loading is active the source weights live on CPU, making
        # these derived tensors CPU-resident too — which then causes a device
        # mismatch at the bmm calls in forward.  Explicitly place on HPU.
        self.W_UV: torch.Tensor = self.W_UV.contiguous().to("hpu")
        self.W_UK_T: torch.Tensor = self.W_UK_T.contiguous().to("hpu")

    # NOTE(Chendi): PR25184 using output buffer as default, which can't be used in HPU Graph,
    # so we override and always return a new tensor
    def _v_up_proj(self, x):
        """Apply the per-head ``W_UV`` projection and return a new ``(B, N * V)`` tensor.

        Args:
            x: Tensor viewable as ``(-1, num_heads, kv_lora_rank)``.
        """
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        x = torch.bmm(x, self.W_UV)
        # Convert from (N, B, V) to (B, N * V)
        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
        return x

batch2block_matmul instance-attribute

batch2block_matmul = (
    B2BMatmul() if not enable_fp8_attn else FP8Matmul()
)

block2batch_matmul instance-attribute

block2batch_matmul = (
    B2BMatmul() if not enable_fp8_attn else FP8Matmul()
)

enable_fp8_attn instance-attribute

enable_fp8_attn = (
    kv_cache_dtype == "fp8_inc"
    and get("QUANT_CONFIG", None) is None
)

fused_scaled_dot_product_attention instance-attribute

fused_scaled_dot_product_attention = (
    None
    if HPUFusedSDPA is None
    else ModuleFusedSDPA(HPUFusedSDPA)
)

head_size instance-attribute

head_size = head_size

is_aiter_triton_fp4_bmm_enabled instance-attribute

is_aiter_triton_fp4_bmm_enabled = (
    is_fp4bmm_enabled() and dtype == bfloat16
)

is_aiter_triton_fp8_bmm_enabled instance-attribute

is_aiter_triton_fp8_bmm_enabled = is_fp8bmm_enabled()

kv_b_proj instance-attribute

kv_b_proj = kv_b_proj

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

latent_cache_k instance-attribute

latent_cache_k = (
    VLLMKVCache()
    if not enable_fp8_attn
    else VLLMFP8KVCache()
)

matmul_av instance-attribute

matmul_av = Matmul() if not enable_fp8_attn else FP8Matmul()

matmul_qk instance-attribute

matmul_qk = Matmul() if not enable_fp8_attn else FP8Matmul()

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

prefill_impl instance-attribute

prefill_impl = prompt_attn_impl

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

qk_head_dim instance-attribute

qk_head_dim = qk_head_dim

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

scale instance-attribute

scale = float(scale)

sinks instance-attribute

sinks = sinks

softmax instance-attribute

softmax = Softmax()

use_merged_prefill instance-attribute

use_merged_prefill = merged_prefill

v_head_dim instance-attribute

v_head_dim = v_head_dim

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
    sinks: Optional[Tensor] = None,
    **kwargs,
) -> None
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    # MLA Specific Arguments
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
    sinks: Optional[torch.Tensor] = None,
    **kwargs,
) -> None:
    """Store the MLA dimensions and build the matmul/softmax/cache helper modules.

    Args:
        num_heads: Number of attention heads.
        head_size: Head size; stored on the instance.
        scale: Softmax scale; stored as a Python ``float``.
        num_kv_heads: Number of KV heads.
        alibi_slopes: Must be falsy; ALiBi is not supported here.
        sliding_window: Must be falsy; sliding window is not supported here.
        kv_cache_dtype: KV-cache dtype string; ``'fp8_inc'`` enables the
            FP8 modules when the ``QUANT_CONFIG`` env var is unset.
        logits_soft_cap: Must be falsy; soft-capping is not supported here.
        attn_type: Must equal ``AttentionType.DECODER``.
        kv_sharing_target_layer_name: Accepted but not used in this method.
        q_lora_rank: Stored on the instance.
        kv_lora_rank: Stored on the instance.
        qk_nope_head_dim: Stored on the instance.
        qk_rope_head_dim: Stored on the instance.
        qk_head_dim: Stored on the instance.
        v_head_dim: Stored on the instance.
        kv_b_proj: Projection layer; its weight dtype is inspected below.
        sinks: Optional tensor whose first dimension must equal ``num_heads``.
        **kwargs: Accepted and ignored by this method.

    Raises:
        NotImplementedError: If any of ``alibi_slopes``, ``sliding_window``
            or ``logits_soft_cap`` is truthy, or ``attn_type`` is not
            ``AttentionType.DECODER``.
        AssertionError: If the prefill impl is ``'fsdpa_impl'`` while
            ``alibi_slopes`` is set, or ``sinks`` has the wrong leading size.
    """
    torch.nn.Module.__init__(self)

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.qk_head_dim = qk_head_dim
    self.v_head_dim = v_head_dim
    self.kv_b_proj = kv_b_proj

    # NOTE(kzawora): restore this once https://github.com/vllm-project/vllm/pull/25385 is merged
    #MLACommonImpl.__init__(self, num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window,
    #                       kv_cache_dtype, logits_soft_cap, attn_type, kv_sharing_target_layer_name, **kwargs)

    # FP8 attention modules are used only for the 'fp8_inc' cache dtype and
    # only when the QUANT_CONFIG environment variable is not set.
    self.enable_fp8_attn = kv_cache_dtype == 'fp8_inc' and os.environ.get('QUANT_CONFIG', None) is None
    self.matmul_qk = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.softmax = Softmax()
    self.matmul_av = Matmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.batch2block_matmul = B2BMatmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.block2batch_matmul = B2BMatmul() if not self.enable_fp8_attn \
        else FP8Matmul()
    self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \
        else VLLMFP8KVCache()
    HPUFusedSDPA = kernels.fsdpa()
    self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
        else ModuleFusedSDPA(HPUFusedSDPA)

    # Best effort: when the FP8 fused kernel is importable and FP8 attention
    # is enabled, it replaces the regular fused SDPA module; otherwise the
    # module chosen above is kept.
    try:
        from habana_frameworks.torch.hpex.kernels import fp8_fused_sdpa
        if self.enable_fp8_attn:
            self.fused_scaled_dot_product_attention = ModuleFP8FusedSDPA(fp8_fused_sdpa)
    except ImportError:
        pass

    self.use_merged_prefill = get_config().merged_prefill
    self.prefill_impl = get_config().prompt_attn_impl
    assert self.prefill_impl != 'fsdpa_impl' or alibi_slopes is None, \
        'Prefill with FusedSDPA not supported with alibi slopes!'
    # NOTE(review): these two flags are not read in this method; presumably
    # inherited MLACommonImpl code expects them — confirm.
    self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
    # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
    self.is_aiter_triton_fp4_bmm_enabled = (rocm_aiter_ops.is_fp4bmm_enabled()
                                            and self.kv_b_proj.weight.dtype == torch.bfloat16)

    unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
    if any(unsupported_features):
        raise NotImplementedError("HPUMLAImpl does not support one of the following: "
                                  "alibi_slopes, sliding_window, "
                                  "logits_soft_cap")

    if attn_type != AttentionType.DECODER:
        # The message names this class (it previously named TritonMLAImpl).
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "HPUMLAImpl")
    self.sinks = sinks
    if sinks is not None:
        assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
                                             f"heads in the layer. Sinks shape: {sinks.shape}, "
                                             f"num_heads: {num_heads}.")

_v_up_proj

_v_up_proj(x)
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _v_up_proj(self, x):
    """Apply the per-head ``W_UV`` projection and return a new ``(B, N * V)`` tensor.

    ``x`` must be viewable as ``(-1, num_heads, kv_lora_rank)``.
    """
    # (B, N, L) -> (N, B, L) so the head axis becomes the bmm batch axis.
    heads_first = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
    # (N, B, L) @ (N, L, V) -> (N, B, V)
    projected = torch.bmm(heads_first, self.W_UV)
    # (N, B, V) -> (B, N, V), then flatten the heads into the feature axis.
    flat_width = self.num_heads * self.v_head_dim
    return projected.transpose(0, 1).reshape(-1, flat_width)

forward_mha

forward_mha(
    q: Tensor,
    latent_vec_k: Tensor,
    k_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
) -> Tensor
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def forward_mha(  # type: ignore
        self, q: torch.Tensor, latent_vec_k: torch.Tensor, k_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata) -> torch.Tensor:
    """Run prefill attention over the (optionally prefix-extended) latent keys.

    When ``attn_metadata.block_list`` is set, cached latents are fetched and
    placed in front of ``latent_vec_k``. The latent is then up-projected with
    ``kv_b_proj`` into per-head keys and values, and attention is computed by
    ``ops.prompt_attention``. Returns a tensor reshaped to
    ``(-1, num_heads * v_head_dim)``.
    """
    # ---- optional prefix cache ----
    if attn_metadata.block_list is not None:
        # vllm-gaudi may hand over the kv cache as a tuple; MLA uses only the key part.
        if isinstance(k_cache, tuple):
            k_cache = k_cache[0]
        blocked_cache = k_cache.unflatten(0, (-1, attn_metadata.block_size))
        cached = self.latent_cache_k.fetch_from_cache(blocked_cache, attn_metadata.block_list)
        cached = cached.view(-1, cached.shape[-1])
        latent_vec_k = torch.concat((cached, latent_vec_k), dim=0)

    # Split the latent vector into its compressed part and its rotary part.
    latent_part, rope_part = latent_vec_k.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    rope_part = rope_part.view(-1, 1, self.qk_rope_head_dim)

    # Up-project, then separate per-head key (non-rotary) and value slices.
    up_projected = self.kv_b_proj(latent_part)[0]
    up_projected = up_projected.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = up_projected.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    # The single rotary slice is broadcast across heads before concatenation.
    rope_per_head = rope_part.expand((*k_nope.shape[:-1], -1))
    k = torch.cat((k_nope, rope_per_head), dim=-1)

    if self.use_merged_prefill:
        batch_size = 1
    else:
        assert attn_metadata.seq_lens_tensor is not None, \
            "seq_lens_tensor must be provided for prefill attention"
        batch_size = attn_metadata.seq_lens_tensor.shape[0]

    q = q.view(batch_size, -1, self.num_heads, self.qk_head_dim)
    k = k.view(batch_size, -1, self.num_heads, self.qk_head_dim)
    v = v.view(batch_size, -1, self.num_heads, self.v_head_dim)

    # Zero-pad V's last dim up to the query head dim; trimmed again after attention.
    v_padded = v
    if self.qk_head_dim - self.v_head_dim > 0:
        zeros = torch.zeros(*v.shape[:-1], q.shape[-1] - v.shape[-1], device=v.device, dtype=v.dtype)
        v_padded = torch.cat((v, zeros), dim=-1)

    attn_out = ops.prompt_attention(impl=self.prefill_impl,
                                    query=q,
                                    key=k,
                                    value=v_padded,
                                    is_causal=True,
                                    attn_bias=attn_metadata.attn_bias,
                                    position_bias=None,
                                    valid_seq_lengths=attn_metadata.seq_lens_tensor,
                                    scale=self.scale,
                                    matmul_qk_op=self.matmul_qk,
                                    softmax_op=self.softmax,
                                    matmul_av_op=self.matmul_av,
                                    keys_fetch_func=self.latent_cache_k.fetch_from_cache,
                                    values_fetch_func=None,
                                    fsdpa_op=self.fused_scaled_dot_product_attention)

    # Drop the padded columns so only the real value width remains.
    attn_out = attn_out.view(batch_size, -1, self.num_heads, q.shape[-1])
    attn_out = attn_out[..., :v.shape[-1]]
    return attn_out.reshape(-1, self.num_heads * v.shape[-1])

forward_mqa

forward_mqa(
    q_nope: Tensor,
    q_pe: Tensor,
    k_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
) -> Tensor
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def forward_mqa(  # type: ignore
        self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata) -> torch.Tensor:
    """Run decode attention against the paged latent cache.

    The query is ``q_nope`` and ``q_pe`` concatenated on the last dim. The
    cache (the first element when ``k_cache`` is a tuple) gets a singleton
    dim inserted at position 1, and no value cache is passed. Returns the
    result of ``HPUPagedAttention.forward_decode``.
    """
    if isinstance(k_cache, tuple):
        # NOTE(review): the unpacked values are not read afterwards; the call
        # is kept so behaviour (including any error it raises) is unchanged.
        key_cache, value_cache, k_scales, v_scales = \
            HPUPagedAttention.split_kv_cache(k_cache, self.num_kv_heads, self.head_size)
        k_cache = k_cache[0]  # Use only key_cache for MLA

    full_query = torch.cat([q_nope, q_pe], dim=-1)
    if k_cache is None:
        latent_cache = None
    else:
        latent_cache = k_cache.unsqueeze(1)

    return HPUPagedAttention.forward_decode(query=full_query,
                                            key_cache=latent_cache,
                                            value_cache=None,
                                            block_list=attn_metadata.block_list,
                                            block_mapping=attn_metadata.block_mapping,
                                            block_bias=attn_metadata.attn_bias,
                                            block_groups=attn_metadata.block_groups,
                                            block_size=attn_metadata.block_size,
                                            scale=self.scale,
                                            matmul_qk_op=self.matmul_qk,
                                            matmul_av_op=self.matmul_av,
                                            batch2block_matmul_op=self.batch2block_matmul,
                                            block2batch_matmul_op=self.block2batch_matmul,
                                            keys_fetch_func=self.latent_cache_k.fetch_from_cache,
                                            values_fetch_func=None,
                                            kv_lora_rank=self.kv_lora_rank)

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm_gaudi/attention/backends/hpu_attn.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Run the parent post-load step, then make ``W_UV``/``W_UK_T`` contiguous on HPU.

    Args:
        act_dtype: Passed through unchanged to the parent implementation.
    """
    super().process_weights_after_loading(act_dtype)
    # W_UV and W_UK_T are plain tensor attributes (not nn.Parameter or
    # register_buffer), so model.to('hpu') won't move them.  When INC
    # CPU-first loading is active the source weights live on CPU, making
    # these derived tensors CPU-resident too — which then causes a device
    # mismatch at the bmm calls in forward.  Explicitly place on HPU.
    self.W_UV: torch.Tensor = self.W_UV.contiguous().to("hpu")
    self.W_UK_T: torch.Tensor = self.W_UK_T.contiguous().to("hpu")

HPUMLAMetadata dataclass

Bases: HPUAttentionMetadata, AttentionMetadata

Source code in vllm_gaudi/attention/backends/hpu_attn.py
@dataclass
class HPUMLAMetadata(HPUAttentionMetadata, AttentionMetadata):
    """Attention metadata type for the HPU MLA backend.

    Declares no fields of its own; everything comes from its base classes.
    """
    pass

__init__

__init__(
    block_list: Optional[Tensor],
    block_mapping: Optional[Tensor],
    block_usage: Optional[Tensor],
    block_groups: Optional[Tensor],
    alibi_blocks: Optional[Tensor],
    is_prompt: bool,
    block_size: int,
    prep_initial_states: bool,
    slot_mapping: Tensor,
    attn_bias: Optional[Tensor],
    seq_lens_tensor: Optional[Tensor],
    context_lens_tensor: Optional[Tensor],
    input_positions: Tensor,
    seq_lens: Optional[list[int]] = None,
    encoder_seq_lens: Optional[list[int]] = None,
    encoder_seq_lens_tensor: Optional[Tensor] = None,
    max_encoder_seq_len: Optional[int] = None,
    cross_block_list: Optional[Tensor] = None,
    cross_slot_mapping: Optional[Tensor] = None,
    cross_block_mapping: Optional[Tensor] = None,
    cross_block_groups: Optional[Tensor] = None,
    cross_block_usage: Optional[Tensor] = None,
    cross_attn_bias: Optional[Tensor] = None,
    window_block_list: Optional[Tensor] = None,
    window_slot_mapping: Optional[Tensor] = None,
    window_block_mapping: Optional[Tensor] = None,
    window_block_groups: Optional[Tensor] = None,
    window_block_usage: Optional[Tensor] = None,
    window_attn_bias: Optional[Tensor] = None,
    chunked_slot_mapping: Optional[Tensor] = None,
    chunked_attn_bias: Optional[Tensor] = None,
    chunked_block_mapping: Optional[Tensor] = None,
    chunked_block_list: Optional[Tensor] = None,
    chunked_block_groups: Optional[Tensor] = None,
    chunked_block_usage: Optional[Tensor] = None,
    has_initial_states_p: Optional[Tensor] = None,
    last_chunk_indices_p: Optional[Tensor] = None,
    load_indices_tensor: Optional[Tensor] = None,
    store_indices_tensor: Optional[Tensor] = None,
) -> None

_make_decode_alibi_bias

_make_decode_alibi_bias(
    alibi_blocks: Tensor, alibi_slopes: Tensor, dtype: dtype
) -> Tensor

Create the ALiBi position bias tensor for decode stage. Uses stored alibi_blocks and slopes for final scaling. Scales with number of blocks, not with batch size.

Parameters:

Name Type Description Default
alibi_blocks Tensor

shape = [num_blocks, block_size]

required
alibi_slopes Tensor

shape = [num_heads]

required
dtype dtype

torch.dtype

required

Returns:

Type Description
Tensor

A per-head bias tensor of shape [num_blocks, num_heads, block_size].

Tensor

Each row encodes position-dependent ALiBi slopes for decoding steps.

Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _make_decode_alibi_bias(
    alibi_blocks: torch.Tensor,
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Build the decode-stage ALiBi bias, one slice per attention head.

    The block positions are replicated across heads and each head's copy is
    scaled by that head's slope. The result grows with the number of blocks,
    not with the batch size.

    Args:
        alibi_blocks: shape = [num_blocks, block_size]
        alibi_slopes: shape = [num_heads]
        dtype: dtype of the returned tensor

    Returns:
        Tensor of shape [num_blocks, num_heads, block_size], allocated on the
        device of ``alibi_slopes``.
    """
    out_shape = (alibi_blocks.size(0), alibi_slopes.shape[0], alibi_blocks.size(-1))
    bias_by_head = torch.empty(out_shape, device=alibi_slopes.device, dtype=dtype)

    # NOTE(Tanner):
    # .copy_ was not performing broadcasting of bias
    # to all 32 heads in Eager mode, so slice assignment is used to
    # replicate the block positions over the head axis.
    bias_by_head[:, :] = alibi_blocks.unsqueeze(-2)

    # Scale every head's copy by its slope, in place.
    bias_by_head.mul_(alibi_slopes[None, :, None])
    return bias_by_head

_make_prompt_alibi_bias

_make_prompt_alibi_bias(
    alibi_slopes: Tensor, seq_len: int, dtype: dtype
) -> Tensor

Create the ALiBi position bias tensor for prompt stage. This tensor is reused or tiled as needed for each forward pass. Does not scale with batch size or number of blocks.

Parameters:

Name Type Description Default
alibi_slopes Tensor

shape = [num_heads]

required
seq_len int

int

required
dtype dtype

torch.dtype

required

Returns:

Type Description
Tensor

A per-head bias tensor of shape [1, num_heads, seq_len, seq_len].

Tensor

This bias encodes positional information via ALiBi slopes.

Source code in vllm_gaudi/attention/backends/hpu_attn.py
def _make_prompt_alibi_bias(
    alibi_slopes: torch.Tensor,
    seq_len: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Build the prompt-stage ALiBi bias, one slice per attention head.

    Entry ``[0, h, i, j]`` equals ``(j - i) * alibi_slopes[h]``. The size
    depends only on ``seq_len`` and the number of heads, not on batch size
    or block count.

    Args:
        alibi_slopes: shape = [num_heads]
        seq_len: side length of the square position grid
        dtype: dtype of the returned tensor

    Returns:
        Tensor of shape [1, num_heads, seq_len, seq_len], allocated on the
        device of ``alibi_slopes``.
    """
    target_device = alibi_slopes.device

    # Signed distance between every (row, column) position pair: column - row.
    positions = torch.arange(seq_len, dtype=dtype, device=target_device)
    distance = positions[None, :] - positions[:, None]  # [seq_len, seq_len]

    # Allocated at exactly seq_len per side (no padding to a multiple of 8).
    out_shape = (1, alibi_slopes.shape[0], seq_len, seq_len)
    bias_by_head = torch.empty(out_shape, device=target_device, dtype=dtype)

    # Replicate the distance grid into every head, then scale by head slope.
    bias_by_head[:, :] = distance
    bias_by_head.mul_(alibi_slopes[:, None, None])
    return bias_by_head