# Out-of-tree variant of MambaMixer2, registered through the
# `MambaMixer2.register_oot` decorator below.
#
# Overview of what the methods in this class do:
#   * __init__           builds conv1d / in_proj / out_proj / norm and the
#                        tensor-parallel sharded A, D and dt_bias parameters.
#   * _init_split_weights  clones the gate and states slices of the fused
#                        in_proj weight (and bias) into standalone tensors.
#   * forward            runs the two split projections with F.linear, then
#                        conv + SSM, the gated norm and out_proj.
#   * conv_ssm_forward   causal conv1d + state-space update for prefill or
#                        decode, writing the result into a caller buffer.
@MambaMixer2.register_oot
class HPUMambaMixer2(MambaMixer2):
def __init__(
    self,
    hidden_size: int,
    ssm_state_size: int,
    conv_kernel_size: int,
    intermediate_size: int,
    use_conv_bias: bool,
    use_bias: bool,
    n_groups: int = 1,
    num_heads: int = 128,
    head_dim: int = 64,
    rms_norm_eps: float = 1e-5,
    activation: str = "silu",
    use_rms_norm: bool = True,
    model_config: ModelConfig | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
):
    """Build the projections, parameters and bookkeeping of the mixer.

    Args:
        hidden_size: Input size of ``in_proj`` and output size of
            ``out_proj``.
        ssm_state_size: Per-group SSM state size; multiplied by
            ``n_groups`` to size the B and C projections.
        conv_kernel_size: Used as the ``input_size`` of the linear layer
            that holds the conv1d weights.
        intermediate_size: Size of the gate and of the x projection.
        use_conv_bias: Whether ``conv1d`` has a bias.
        use_bias: Whether ``in_proj`` and ``out_proj`` have a bias.
        n_groups: Number of B/C groups; must be divisible by the tensor
            parallel world size.
        num_heads: Number of SSM heads; must be divisible by the tensor
            parallel world size.
        head_dim: Size of each head.
        rms_norm_eps: ``eps`` forwarded to ``Mixer2RMSNormGated``.
        activation: Activation name forwarded to the conv kernels.
        use_rms_norm: Forwarded to ``Mixer2RMSNormGated``.
        model_config: Stored; queried later for the mamba chunk size.
        cache_config: Stored; queried later for ``enable_prefix_caching``.
        quant_config: Quantization config for ``in_proj`` and
            ``out_proj`` (``conv1d`` is always built with ``None``).
        prefix: Layer name; used for sub-module prefixes and as the key
            in ``static_forward_context``.

    Raises:
        AssertionError: If the tensor parallel world size does not divide
            ``num_heads`` or ``n_groups``.
        ValueError: If ``prefix`` is already registered in
            ``static_forward_context``.
    """
    # Start the MRO lookup *after* MambaMixer2, so MambaMixer2.__init__
    # is not executed; this method performs the layer setup instead.
    super(MambaMixer2, self).__init__()

    self.tp_size = get_tensor_model_parallel_world_size()
    assert num_heads % self.tp_size == 0, ("Tensor parallel world size must divide num heads.")
    # Divisibility is required unconditionally here: there is no special
    # case for n_groups == 1 in this implementation.
    assert n_groups % self.tp_size == 0, ("Tensor parallel world size must divide num_groups.")

    self.ssm_state_size = ssm_state_size
    self.conv_kernel_size = conv_kernel_size
    self.activation = activation
    self.intermediate_size = intermediate_size
    self.head_dim = head_dim
    self.num_heads = num_heads
    self.n_groups = n_groups
    self.num_spec = get_current_vllm_config().num_speculative_tokens

    self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
    self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size

    self.conv1d = MergedColumnParallelLinear(
        input_size=conv_kernel_size,
        output_sizes=[
            intermediate_size,
            self.groups_ssm_state_size,
            self.groups_ssm_state_size,
        ],
        bias=use_conv_bias,
        quant_config=None,
        prefix=f"{prefix}.conv1d",
    )

    self.in_proj = MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[
            intermediate_size,
            intermediate_size,
            self.groups_ssm_state_size,
            self.groups_ssm_state_size,
            self.num_heads,
        ],
        bias=use_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.in_proj",
    )

    # unsqueeze to fit conv1d weights shape into the linear weights shape.
    # Can't do this in `weight_loader` since it already exists in
    # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
    # and `set_weight_attrs` doesn't allow to override it
    self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
    # 2-D view (dims 0 and 2 of the unsqueezed weight) handed to the conv
    # kernels; non-persistent so it is not written to the state dict.
    conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
    self.register_buffer("conv_weights", conv_weights, persistent=False)

    # - these are TPed by heads to reduce the size of the
    #   temporal shape
    self.A = nn.Parameter(torch.empty(
        divide(num_heads, self.tp_size),
        dtype=torch.float32,
    ))
    self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
    self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
    self.use_rms_norm = use_rms_norm

    set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
    # A is loaded through -exp(x.float()) on top of the sharded loader.
    a_weight_loader = composed_weight_loader(sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
    set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
    set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

    self.out_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=use_bias,
        input_is_parallel=True,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )

    self.norm = Mixer2RMSNormGated(intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps)

    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self

    # The tuple is (conv_state, ssm_state)
    self.kv_cache = (torch.tensor([]), torch.tensor([]))
    self.model_config = model_config
    self.cache_config = cache_config
    self.prefix = prefix

    # Pre-compute sizes for forward pass
    self.tped_intermediate_size = self.intermediate_size // self.tp_size
    self.tped_conv_size = self.conv_dim // self.tp_size
    self.tped_dt_size = self.num_heads // self.tp_size
    # Flipped to True by _init_split_weights() once the split in_proj
    # tensors used by forward() exist.
    self._split_weights_ready = False

    # - get hidden_states, B and C after depthwise convolution.
    # Defined once, after the pre-computed sizes it relies on.
    self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
        hidden_states_B_C,
        [
            self.tped_intermediate_size,
            self.groups_ssm_state_size // self.tp_size,
            self.groups_ssm_state_size // self.tp_size,
        ],
        dim=-1,
    )
def forward(
    self,
    hidden_states: torch.Tensor,
    mup_vector: torch.Tensor | None = None,
):
    """Run the mixer on a batch of token hidden states.

    ``_init_split_weights()`` must have been called beforehand: this
    method reads ``_states_weight`` / ``_gate_weight`` (and the matching
    biases), which only that method assigns.

    Args:
        hidden_states: Input activations; flattened to 2-D
            ``(-1, hidden_states.shape[-1])`` before use.
        mup_vector: Optional per-channel multiplier covering the whole
            fused in_proj output, laid out as ``[gate, states]`` along its
            last dimension. May be 1-D or carry leading broadcast dims.

    Returns:
        The ``out_proj`` output, viewed as ``(1, tokens, hidden)`` for a
        prompt step and ``(tokens, 1, hidden)`` otherwise.
    """
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    # 1. Split in_proj into two GEMMs for TPC/MME pipelining.
    #    GEMM 1 (states: x,B,C,dt) is dispatched to the MME first;
    #    GEMM 2 (gate) is dispatched second.  The Gaudi runtime can
    #    overlap GEMM 2 on the MME with conv+SSM TPC work that
    #    depends only on GEMM 1.
    states_proj = F.linear(hidden_states, self._states_weight, self._states_bias)
    gate = F.linear(hidden_states, self._gate_weight, self._gate_bias)

    if mup_vector is not None:
        gate_size = self.tped_intermediate_size
        # Slice the channel (last) dimension explicitly. For a 1-D vector
        # this equals plain `[gate_size:]` / `[:gate_size]`, and it also
        # stays correct when the vector has leading broadcast dimensions
        # such as (1, N), where slicing dim 0 would pick the wrong axis.
        states_proj = states_proj * mup_vector[..., gate_size:]
        gate = gate * mup_vector[..., :gate_size]

    # 2. Prepare output buffer for conv + SSM
    ssm_output = torch.empty(
        [
            hidden_states.shape[0],
            (self.num_heads // self.tp_size) * self.head_dim,
        ],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    # 3. conv + SSM on TPC — overlaps with GEMM 2 on MME
    #    (result is written into ssm_output; the return value is unused)
    self.conv_ssm_forward(states_proj, ssm_output)

    # 4. gated MLP (needs both gate from GEMM 2 and ssm_output)
    hidden_states_varlen = self.norm(ssm_output, gate)

    # 5. Final linear projection
    output, _ = self.out_proj(hidden_states_varlen)

    # NOTE(review): attn_metadata is dereferenced without a None check
    # here, although conv_ssm_forward tolerates a missing attn_metadata.
    if get_forward_context().attn_metadata.is_prompt:
        output = output.view(1, output.shape[0], output.shape[1])
    else:
        output = output.view(output.shape[0], 1, output.shape[1])
    return output
def _init_split_weights(self):
    """Materialise the gate / states halves of in_proj as separate tensors.

    Pre-clone weight slices as standalone contiguous tensors so that
    F.linear sees them as independent parameters.  The Habana bridge
    recognises F.linear and maps it to an optimised MME recipe that
    does NOT require a separate TPC transpose of the weight, unlike
    a raw torch.mm or a non-contiguous view.

    Must be called AFTER checkpoint weights have been loaded into
    self.in_proj.weight and BEFORE PT_COMPILE_ONLY_MODE warmup,
    because .clone() does not copy data in compile-only mode.
    Called from apply_model_specific_patches() in hpu_model_runner.

    The first ``tped_intermediate_size`` rows of the fused weight (and
    entries of the bias, when present) form the gate part; everything
    after them forms the states part. Sets ``_split_weights_ready``.
    """
    split_at = self.tped_intermediate_size
    fused_weight = self.in_proj.weight  # [total_out, hidden_size]
    fused_bias = self.in_proj.bias  # [total_out] or None

    self._states_weight = fused_weight[split_at:].clone()  # [states_out, hidden]
    self._gate_weight = fused_weight[:split_at].clone()  # [gate_out, hidden]

    has_bias = fused_bias is not None
    self._states_bias = fused_bias[split_at:].clone() if has_bias else None
    self._gate_bias = fused_bias[:split_at].clone() if has_bias else None

    self._split_weights_ready = True
def conv_ssm_forward(
    self,
    states_proj: torch.Tensor,
    output: torch.Tensor,
):
    """Apply the causal conv1d and the state-space update.

    Exactly one of the prefill / decode paths runs, selected by
    ``attn_metadata.is_prompt``. Both write their result into ``output``
    in place and update the conv / SSM caches held in ``self.kv_cache``.

    Args:
        states_proj: Projected states laid out as ``[x, B, C, dt]`` along
            the last dimension (gate already split off).
        output: Pre-allocated buffer that receives the SSM output; it is
            viewed as ``(output.shape[0], -1, head_dim)`` for the kernels.

    Returns:
        ``None`` on the regular paths (the result lives in ``output``).
        When no attention metadata is available (profile run) the x part
        of the split states is returned instead and ``output`` is left
        untouched.

    Raises:
        AssertionError: If ``cache_config`` or ``model_config`` is unset,
            or if a prefill step has no padding mask.
    """
    # states_proj contains [x, B, C, dt] (gate already split off).
    hidden_states_B_C, dt = torch.split(
        states_proj,
        [self.tped_conv_size, self.tped_dt_size],
        dim=-1,
    )

    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata

    assert self.cache_config is not None
    enable_prefix_caching = self.cache_config.enable_prefix_caching

    if attn_metadata is None:
        # profile run
        hidden_states_B_C = (hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)).contiguous()
        hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
        return hidden_states

    has_prefill = attn_metadata.is_prompt
    has_decode = not has_prefill

    self_kv_cache = self.kv_cache
    # conv_state = (..., dim, width-1) yet contiguous along 'dim'
    conv_state = self_kv_cache[0]
    ssm_state = self_kv_cache[1]
    load_indices_tensor = attn_metadata.load_indices_tensor[self.cache_group_idx]
    store_indices_tensor = attn_metadata.store_indices_tensor[self.cache_group_idx]

    # Prefix-caching metadata is only consulted for prompt steps.
    if enable_prefix_caching and has_prefill:
        blocks_caching_range = attn_metadata.blocks_caching_range[self.cache_group_idx]
        mamba_chunks_to_block_mapping = attn_metadata.mamba_chunks_to_block_mapping[self.cache_group_idx]
        seqlens_offsets_for_blocks = attn_metadata.seqlens_offsets_for_blocks
    else:
        blocks_caching_range = None
        mamba_chunks_to_block_mapping = None
        seqlens_offsets_for_blocks = None

    has_initial_states_p = attn_metadata.has_initial_states_p
    # is below sufficient to get chunk_size or does it need to passed via metadata
    assert self.model_config is not None
    chunk_size = self.model_config.get_mamba_chunk_size()
    query_start_loc_p = attn_metadata.query_start_loc_p
    last_chunk_indices_p = attn_metadata.last_chunk_indices_p
    padding_mask_flat = attn_metadata.padding_mask_flat

    # Process prefill requests
    if has_prefill:
        assert padding_mask_flat is not None

        # Mask first, transpose second: the multiplication is
        # out-of-place, so a transposed view taken before it would still
        # expose the unmasked values to the conv kernel.
        hidden_states_B_C = hidden_states_B_C * padding_mask_flat
        dt = dt * padding_mask_flat
        x = hidden_states_B_C.transpose(0, 1)

        # 2. Convolution sequence transformation
        hidden_states_B_C = granite_causal_conv1d_fn(
            x,
            self.conv_weights,
            self.conv1d.bias,
            activation=self.activation,
            conv_states=conv_state,
            has_initial_state=has_initial_states_p,
            enable_prefix_caching=enable_prefix_caching,
            load_cache_indices=load_indices_tensor,
            store_cache_indices=store_indices_tensor,
            blocks_caching_range=blocks_caching_range,
            seqlens_offsets_for_blocks=seqlens_offsets_for_blocks,
            metadata=attn_metadata,
            query_start_loc=query_start_loc_p,
            is_prompt=True,
        ).transpose(0, 1)
        hidden_states_B_C = hidden_states_B_C * padding_mask_flat

        hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(hidden_states_B_C)

        # 3. State Space Model sequence transformation
        initial_states = None
        if attn_metadata.prep_initial_states:
            initial_states = ssm_state[load_indices_tensor]

        # NOTE: final output is an in-place update of out tensor
        varlen_states = hpu_mamba_chunk_scan_combined_varlen(
            hidden_states_p.view(hidden_states_p.shape[0], self.num_heads // self.tp_size, self.head_dim),
            dt,
            self.A,
            B_p.view(B_p.shape[0], self.n_groups // self.tp_size, -1),
            C_p.view(C_p.shape[0], self.n_groups // self.tp_size, -1),
            chunk_size=chunk_size,
            D=self.D,
            z=None,
            dt_bias=self.dt_bias,
            cu_seqlens=query_start_loc_p,
            last_chunk_indices=last_chunk_indices_p,
            initial_states=initial_states,
            dt_softplus=True,
            dt_limit=(0.0, float("inf")),
            out=output.view(output.shape[0], -1, self.head_dim),
            state_dtype=ssm_state.dtype,
            padding_mask=padding_mask_flat,
        )

        # Mask the caller's buffer IN PLACE. Rebinding the local name
        # (`output = output * mask`) would leave the buffer the caller
        # reads afterwards unmasked.
        output.mul_(padding_mask_flat.view(output.shape[0], 1).to(output.dtype))

        if enable_prefix_caching:
            ssm_state[mamba_chunks_to_block_mapping] = varlen_states
        else:
            ssm_state[store_indices_tensor] = varlen_states[last_chunk_indices_p]

    # Process decode requests
    if has_decode:
        # 2. Convolution sequence transformation
        hidden_states_B_C = granite_causal_conv1d_update(
            hidden_states_B_C,
            conv_state,
            self.conv_weights,
            self.conv1d.bias,
            self.activation,
            load_cache_indices=load_indices_tensor,
            store_cache_indices=store_indices_tensor,
            query_start_loc=query_start_loc_p,
        )

        hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(hidden_states_B_C)

        # 3. State Space Model sequence transformation
        n_groups = self.n_groups // self.tp_size
        A_d = self.A.to(dtype=torch.float32)  # (nheads,) — keep compact, no expand
        # Broadcast the per-head quantities over head_dim.
        dt = dt[:, :, None].expand(-1, -1, self.head_dim)
        dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
        D_d = self.D[:, None, ...].expand(-1, self.head_dim)
        B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
        C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
        hidden_states_d = hidden_states_d.view(-1, self.num_heads // self.tp_size, self.head_dim)

        # - the hidden is reshaped into (bs, num_heads, head_dim)
        # - mamba_cache_params.ssm_state's slots will be selected
        #   using state_indices_tensor
        # NOTE: final output is an in-place update of out tensor
        hpu_selective_state_update = get_selective_state_update_impl()
        hpu_selective_state_update(
            ssm_state,
            hidden_states_d,
            dt,
            A_d,
            B_d,
            C_d,
            D_d,
            z=None,
            dt_bias=dt_bias,
            dt_softplus=True,
            state_batch_indices=load_indices_tensor,
            dst_state_batch_indices=store_indices_tensor,
            out=output.view(output.shape[0], -1, self.head_dim),
        )