vllm_gaudi.models.llama4
HPU-optimized Llama4 modules, built on the model-registry override pattern.
Llama4 has 48 decoder layers with heterogeneous configurations:
- NoPE layers: no rotary_emb, temperature tuning enabled, Attention backend
- RoPE layers: rotary_emb present, no temperature tuning, ChunkedLocalAttention backend
This module provides:
- HpuLlama4Model: overrides forward() to initialize residual as zeros instead of None, eliminating a torch._dynamo type guard.
- HpuLlama4ForCausalLM: registered via ModelRegistry; applies the branch-free attention patches and attention type unification in __init__.
- Branch-free attention patching: boolean buffer masks and torch.where eliminate Python if/else guards on nope/rotary_emb/temperature_tuning.
- Attention type unification: swaps ChunkedLocalAttention → Attention to eliminate torch._dynamo type guards across layers.
HpuLlama4ForCausalLM
Bases: Llama4ForCausalLM
HPU-optimized Llama4ForCausalLM, registered via ModelRegistry.
Applies the branch-free attention patches and attention type unification, and swaps the inner model class to HpuLlama4Model for the residual fix.
Source code in vllm_gaudi/models/llama4.py
HpuLlama4ForConditionalGeneration
Bases: Llama4ForConditionalGeneration
HPU override of Llama4ForConditionalGeneration.
After the upstream __init__ creates language_model (a Llama4ForCausalLM), swaps the inner model class and applies the branch-free attention patches.
HpuLlama4Model
Bases: Llama4Model
Llama4Model with residual initialized as zeros instead of None.
The upstream LlamaModel.forward() sets residual = None on the first pipeline-parallel rank, which creates a torch._dynamo type guard (None vs. Tensor) that causes recompilation between layer 0 and layers 1-47. Initializing residual as torch.zeros_like(hidden_states) eliminates this guard.
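A minimal sketch of why the zeros initialization does not change results, assuming the decoder layer follows the usual "if residual is None: plain norm, else: add-then-norm" pattern. rms_norm and input_norm are hypothetical stand-ins, not vLLM functions; the point is that adding a zero residual reproduces the None branch exactly, while layer 0 now receives a Tensor like every other layer.

```python
from typing import Optional

import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm without a learned weight; enough for this comparison.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


def input_norm(hidden: torch.Tensor, residual: Optional[torch.Tensor]):
    # Hypothetical mirror of the decoder-layer pattern: a None residual takes
    # the first-layer branch, a Tensor residual takes the add-then-norm branch.
    if residual is None:
        return rms_norm(hidden), hidden
    residual = hidden + residual
    return rms_norm(residual), residual


hidden = torch.randn(4, 8)

# Upstream first-rank behaviour: residual starts as None.
out_none, res_none = input_norm(hidden, None)

# HPU behaviour: residual starts as zeros, so layer 0 takes the Tensor branch too.
out_zeros, res_zeros = input_norm(hidden, torch.zeros_like(hidden))

print(torch.equal(out_none, out_zeros), torch.equal(res_none, res_zeros))
# prints: True True
```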
forward
forward(
input_ids: Tensor | None,
positions: Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: Tensor | None = None,
**extra_layer_kwargs,
) -> (
Tensor
| IntermediateTensors
| tuple[Tensor, list[Tensor]]
)
_HpuLlama4FeedForward
Bases: Module
Unified wrapper for Llama4MoE and LlamaMLP feed_forward modules.
Makes all feed_forward modules the same Python type to eliminate torch._dynamo type guards when iterating across decoder layers.
_apply_branch_free_attention
Apply branch-free attention patches to all Llama4Attention layers.
_apply_hpu_llama4_init_patches
_apply_hpu_llama4_init_patches(model_root: Module) -> None
Shared init-time patches for both CausalLM and ConditionalGeneration.
Swaps the inner model class for the residual=zeros fix and applies attention type unification in compile mode. _unify_feed_forward_types is deferred to post-load via apply_hpu_llama4_post_load_patches() to avoid breaking weight loading.
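The inner-class swap can be sketched in plain Python, assuming it is done by reassigning __class__ on the already-constructed inner model (this page does not show the exact mechanism). The classes below are stand-ins that only borrow the vLLM names, and swap_inner_model_class is a hypothetical helper that handles both the CausalLM layout (root.model) and the ConditionalGeneration layout (root.language_model.model).

```python
class Llama4Model:
    """Plain-Python stand-in for the upstream inner model."""

    def __init__(self):
        self.weights = {"embed": [0.1, 0.2]}  # state built by the upstream __init__

    def forward(self):
        return "residual=None"


class HpuLlama4Model(Llama4Model):
    """Stand-in subclass: only overrides forward(), adds no new state."""

    def forward(self):
        return "residual=zeros"


class CausalLM:
    def __init__(self):
        self.model = Llama4Model()


class ConditionalGeneration:
    def __init__(self):
        self.language_model = CausalLM()


def swap_inner_model_class(model_root) -> None:
    # Hypothetical helper: find the inner model for either wrapper layout.
    holder = getattr(model_root, "language_model", model_root)
    inner = holder.model
    if type(inner) is Llama4Model:
        # In-place class swap: attributes stay on the instance,
        # only method lookup (here, forward) changes.
        inner.__class__ = HpuLlama4Model


causal, multimodal = CausalLM(), ConditionalGeneration()
for root in (causal, multimodal):
    swap_inner_model_class(root)

print(causal.model.forward(), multimodal.language_model.model.forward())
# prints: residual=zeros residual=zeros
```

Reassigning __class__ keeps every attribute created by the upstream __init__, so the swap can happen after construction without rebuilding any state.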
_branchfree_attention_forward
Branch-free Llama4Attention forward.
All layers execute identical code. Boolean buffer masks and torch.where select between RoPE'd and un-RoPE'd, normed and un-normed, and scaled and un-scaled tensors at the data level, so torch.compile sees no Python if/else guards.
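The data-level selection can be illustrated by comparing a branchy forward against a torch.where version for every flag combination. This is a sketch, not the Llama4Attention code: toy_rope is a deliberately simplified rotary embedding, rms_norm stands in for the QK norm, scale stands in for the temperature-tuning factor, and the rope → norm → scale order is an assumption.

```python
import itertools

import torch


def toy_rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Toy rotary embedding: rotate channel pairs by an angle equal to the position.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    angle = positions.float()[:, None]
    cos, sin = torch.cos(angle), torch.sin(angle)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


def branchy(q, positions, use_rope: bool, use_norm: bool, use_temp: bool, scale):
    # Python if/else: a compiler has to specialize on each flag value.
    if use_rope:
        q = toy_rope(q, positions)
    if use_norm:
        q = rms_norm(q)
    if use_temp:
        q = q * scale
    return q


def branch_free(q, positions, rope_mask, norm_mask, temp_mask, scale):
    # Both sides are always computed; torch.where picks one at the data level.
    q = torch.where(rope_mask, toy_rope(q, positions), q)
    q = torch.where(norm_mask, rms_norm(q), q)
    return torch.where(temp_mask, q * scale, q)


q = torch.randn(5, 8)
positions = torch.arange(5)
scale = torch.full((5, 1), 1.5)  # stand-in for per-token temperature scales

for flags in itertools.product([False, True], repeat=3):
    masks = [torch.tensor(f) for f in flags]  # 0-d bool tensors, like buffers
    ref = branchy(q, positions, *flags, scale)
    out = branch_free(q, positions, *masks, scale)
    assert torch.allclose(ref, out)
print("all 8 flag combinations match")
```

The cost of this design is that both sides of each torch.where are always computed, so every layer pays for RoPE, norm, and scaling; in exchange, layers that differ only in their mask values run the same code.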
_patch_attention_module
Patch a single Llama4Attention module to be branch-free.
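One way to patch a single module is sketched below, assuming each Python flag becomes a non-persistent boolean buffer and the new forward is bound to the instance with types.MethodType. ToyAttention, branch_free_forward, and patch_attention_module are hypothetical names, and flip(-1) is only a placeholder for the rotary embedding.

```python
import types

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for Llama4Attention with a per-layer Python flag."""

    def __init__(self, use_rope: bool):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.use_rope = use_rope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.proj(x)
        if self.use_rope:  # Python branch on a per-layer attribute
            y = y.flip(-1)  # placeholder for the rotary embedding
        return y


def branch_free_forward(self, x: torch.Tensor) -> torch.Tensor:
    y = self.proj(x)
    # Selection is driven by a bool buffer, so every layer runs the same code.
    return torch.where(self._rope_mask, y.flip(-1), y)


def patch_attention_module(module: ToyAttention) -> None:
    # Assumption: non-persistent, so the mask never appears in state_dict().
    module.register_buffer("_rope_mask", torch.tensor(module.use_rope), persistent=False)
    # Bind the new forward to this instance only; the class itself is untouched.
    module.forward = types.MethodType(branch_free_forward, module)


torch.manual_seed(0)
x = torch.randn(2, 4)
for use_rope in (False, True):
    layer = ToyAttention(use_rope)
    expected = layer(x)
    patch_attention_module(layer)
    assert torch.equal(layer(x), expected)
    assert "_rope_mask" not in layer.state_dict()
print("patched layers match the original forward")
```

Registering the mask as a buffer rather than keeping a plain bool attribute means it moves with the module under .to(device), and the selection depends on tensor data instead of a Python value.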
_unify_attention_types
Change the class of ChunkedLocalAttention instances to Attention.
Because ChunkedLocalAttention does NOT override forward(), the class swap does not change forward behavior. get_kv_cache_spec is preserved as an instance method.
WARNING: After this swap, isinstance(x, ChunkedLocalAttention) returns False. This is currently safe because maybe_set_chunked_attention_layers uses string matching on backend names; if upstream ever switches to isinstance checks, this will need updating. A _was_chunked_local marker is set for future detection.
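The swap and the preserved get_kv_cache_spec can be reproduced in plain Python. Attention and ChunkedLocalAttention here are stand-ins with the relationship described above (the subclass overrides get_kv_cache_spec but not forward); unify_attention_types is a hypothetical helper, and pinning the method with types.MethodType is one plausible way to keep it as an instance method, not necessarily the one vllm_gaudi uses.

```python
import types


class Attention:
    """Plain-Python stand-in for the generic attention layer."""

    def __init__(self, name: str):
        self.name = name

    def forward(self, x: str) -> str:
        return f"{self.name}:{x}"

    def get_kv_cache_spec(self) -> str:
        return "full"


class ChunkedLocalAttention(Attention):
    """Stand-in subclass: overrides get_kv_cache_spec(), but not forward()."""

    def get_kv_cache_spec(self) -> str:
        return "chunked_local"


def unify_attention_types(layers) -> int:
    swapped = 0
    for layer in layers:
        if type(layer) is ChunkedLocalAttention:
            # Pin the subclass method on the instance; after the swap,
            # class lookup alone would find Attention's version instead.
            layer.get_kv_cache_spec = types.MethodType(
                ChunkedLocalAttention.get_kv_cache_spec, layer
            )
            layer.__class__ = Attention
            layer._was_chunked_local = True  # marker, since isinstance() no longer works
            swapped += 1
    return swapped


layers = [Attention("nope"), ChunkedLocalAttention("rope")]
print(unify_attention_types(layers))  # 1
print({type(layer).__name__ for layer in layers})  # {'Attention'}
print(layers[1].get_kv_cache_spec())  # chunked_local
print(isinstance(layers[1], ChunkedLocalAttention))  # False
```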
_unify_feed_forward_types
_unify_feed_forward_types(layers) -> int
Wrap all feed_forward modules in a unified type.
Replaces heterogeneous Llama4MoE / LlamaMLP feed_forward attributes with _HpuLlama4FeedForward wrappers so torch._dynamo sees one type.
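A sketch of the wrapper and the unification loop, using toy modules in place of Llama4MoE and LlamaMLP. FeedForwardWrapper only approximates _HpuLlama4FeedForward, and reading the int return value as "number of modules wrapped" is an assumption, since the signature only says the function returns an int.

```python
import torch
from torch import nn


class ToyMoE(nn.Module):
    """Stand-in for Llama4MoE: owns an `experts` submodule."""

    def __init__(self, dim: int):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return sum(expert(x) for expert in self.experts)


class ToyMLP(nn.Module):
    """Stand-in for LlamaMLP."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class FeedForwardWrapper(nn.Module):
    """Sketch of a unified wrapper: one Python type, whatever it wraps."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)


class ToyLayer(nn.Module):
    def __init__(self, feed_forward: nn.Module):
        super().__init__()
        self.feed_forward = feed_forward


def unify_feed_forward_types(layers) -> int:
    # Wrap each feed_forward once; the int is assumed to be the number wrapped.
    wrapped = 0
    for layer in layers:
        if not isinstance(layer.feed_forward, FeedForwardWrapper):
            layer.feed_forward = FeedForwardWrapper(layer.feed_forward)
            wrapped += 1
    return wrapped


layers = nn.ModuleList([ToyLayer(ToyMoE(4)), ToyLayer(ToyMLP(4))])
print(unify_feed_forward_types(layers))  # 2
print({type(layer.feed_forward).__name__ for layer in layers})  # {'FeedForwardWrapper'}
```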
apply_hpu_llama4_post_load_patches
Apply patches that must run after load_weights().
_unify_feed_forward_types wraps feed_forward modules in a unified type. This must happen after weight loading because the wrapper changes named_parameters() keys (adding an .inner. segment) and hides the .experts attribute, which would break load_moe_expert_weights() in upstream Llama4Model.
Called from apply_model_specific_patches() in hpu_model_runner.py.
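The effect that forces this ordering is easy to see on a toy module. All class names below are hypothetical; the sketch only shows that storing the original module as .inner rewrites every parameter name and hides attributes such as .experts, which is what a loader keyed on the original names or on .experts would trip over.

```python
import torch
from torch import nn


class ToyMoE(nn.Module):
    """Hypothetical stand-in for Llama4MoE with an `experts` submodule."""

    def __init__(self):
        super().__init__()
        self.experts = nn.Linear(4, 4)


class Wrapper(nn.Module):
    """Hypothetical unified wrapper that stores the original module as `inner`."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner


class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.feed_forward = ToyMoE()


layer = Layer()
before = [name for name, _ in layer.named_parameters()]
layer.feed_forward = Wrapper(layer.feed_forward)
after = [name for name, _ in layer.named_parameters()]

print(before)  # ['feed_forward.experts.weight', 'feed_forward.experts.bias']
print(after)   # ['feed_forward.inner.experts.weight', 'feed_forward.inner.experts.bias']
print(hasattr(layer.feed_forward, "experts"))  # False
```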