Skip to content

vllm_gaudi.ops.hpu_mamba_mixer2

HPUMambaMixer2

Bases: MambaMixer2

Source code in vllm_gaudi/ops/hpu_mamba_mixer2.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
@MambaMixer2.register_oot
class HPUMambaMixer2(MambaMixer2):
    """``MambaMixer2`` subclass registered through ``MambaMixer2.register_oot``.

    ``forward`` replaces the single fused ``in_proj`` GEMM with two
    ``F.linear`` calls (states and gate) that use weight copies prepared by
    ``_init_split_weights``, runs the convolution and SSM step through
    ``conv_ssm_forward``, and finishes with the gated norm and ``out_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        ssm_state_size: int,
        conv_kernel_size: int,
        intermediate_size: int,
        use_conv_bias: bool,
        use_bias: bool,
        n_groups: int = 1,
        num_heads: int = 128,
        head_dim: int = 64,
        rms_norm_eps: float = 1e-5,
        activation: str = "silu",
        use_rms_norm: bool = True,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections, SSM parameters and bookkeeping of one layer.

        Args:
            hidden_size: Input width of ``in_proj`` and output width of
                ``out_proj``.
            ssm_state_size: Multiplied by ``n_groups`` to size the B and C
                slices (``groups_ssm_state_size``).
            conv_kernel_size: ``input_size`` of the ``conv1d`` parallel-linear
                layer that holds the convolution weights.
            intermediate_size: Width of the gate and hidden-state slices
                before tensor-parallel sharding.
            use_conv_bias: Whether ``conv1d`` is created with a bias.
            use_bias: Whether ``in_proj`` and ``out_proj`` are created with a
                bias.
            n_groups: Number of B/C groups; must be divisible by the
                tensor-parallel world size.
            num_heads: Number of heads; must be divisible by the
                tensor-parallel world size.
            head_dim: Per-head width used when reshaping hidden states.
            rms_norm_eps: ``eps`` forwarded to ``Mixer2RMSNormGated``.
            activation: Activation name forwarded to the convolution kernels.
            use_rms_norm: Forwarded to ``Mixer2RMSNormGated``.
            model_config: Read in ``conv_ssm_forward`` to obtain the mamba
                chunk size.
            cache_config: Read in ``conv_ssm_forward`` for
                ``enable_prefix_caching``.
            quant_config: Passed to ``in_proj`` and ``out_proj``; ``conv1d``
                is always built with ``quant_config=None``.
            prefix: Layer name; used for sub-layer prefixes and as the key in
                ``compilation_config.static_forward_context``.

        Raises:
            AssertionError: If the tensor-parallel world size does not divide
                ``num_heads`` or ``n_groups``.
            ValueError: If ``prefix`` is already present in
                ``static_forward_context``.
        """
        # Starts the MRO lookup *after* MambaMixer2, so MambaMixer2.__init__
        # itself is not run; every attribute is built below instead.
        super(MambaMixer2, self).__init__()

        self.tp_size = get_tensor_model_parallel_world_size()

        assert num_heads % self.tp_size == 0, ("Tensor parallel world size must divide num heads.")

        # Unlike the "or n_groups == 1" relaxation, this class requires plain
        # divisibility (n_groups == 1 with tp_size > 1 is rejected).
        assert n_groups % self.tp_size == 0, ("Tensor parallel world size must divide num_groups.")

        self.ssm_state_size = ssm_state_size
        self.conv_kernel_size = conv_kernel_size
        self.activation = activation

        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.n_groups = n_groups

        self.num_spec = get_current_vllm_config().num_speculative_tokens

        self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
        self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size

        self.conv1d = MergedColumnParallelLinear(
            input_size=conv_kernel_size,
            output_sizes=[
                intermediate_size,
                self.groups_ssm_state_size,
                self.groups_ssm_state_size,
            ],
            bias=use_conv_bias,
            quant_config=None,
            prefix=f"{prefix}.conv1d",
        )

        # Output slices, in order: gate, x, B, C, dt.  forward() and
        # _init_split_weights() rely on the gate being the first slice.
        self.in_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[
                intermediate_size,
                intermediate_size,
                self.groups_ssm_state_size,
                self.groups_ssm_state_size,
                self.num_heads,
            ],
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj",
        )

        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
        # and `set_weight_attrs` doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
        # 2-D view (dim 1 dropped again) that is handed to the conv kernels.
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
        self.register_buffer("conv_weights", conv_weights, persistent=False)

        # - these are TPed by heads to reduce the size of the
        #   temporal shape
        self.A = nn.Parameter(torch.empty(
            divide(num_heads, self.tp_size),
            dtype=torch.float32,
        ))
        self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
        self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
        self.use_rms_norm = use_rms_norm

        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
        # A is stored as -exp(checkpoint value), computed in float32 at load time.
        a_weight_loader = composed_weight_loader(sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.out_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=use_bias,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.norm = Mixer2RMSNormGated(intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps)

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # The tuple is (conv_state, ssm_state)
        self.kv_cache = (torch.tensor([]), torch.tensor([]))

        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

        # Pre-compute sizes for forward pass
        self.tped_intermediate_size = self.intermediate_size // self.tp_size
        self.tped_conv_size = self.conv_dim // self.tp_size
        self.tped_dt_size = self.num_heads // self.tp_size

        # Set to True by _init_split_weights(); checked in forward().
        self._split_weights_ready = False

        # - get hidden_states, B and C after depthwise convolution.
        self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
            hidden_states_B_C,
            [
                self.tped_intermediate_size,
                self.groups_ssm_state_size // self.tp_size,
                self.groups_ssm_state_size // self.tp_size,
            ],
            dim=-1,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        mup_vector: torch.Tensor | None = None,
    ):
        """Apply the mixer to ``hidden_states``.

        Args:
            hidden_states: Input activations; flattened to
                ``(-1, hidden_states.shape[-1])`` before use.
            mup_vector: Optional per-output scale for the fused ``in_proj``
                result; the first ``tped_intermediate_size`` entries scale the
                gate, the remainder the states projection.

        Returns:
            The ``out_proj`` result viewed as ``(1, tokens, hidden)`` when
            ``attn_metadata.is_prompt`` is true, else ``(tokens, 1, hidden)``.

        Raises:
            RuntimeError: If ``_init_split_weights`` has not been called.
        """
        if not self._split_weights_ready:
            # Fail with an actionable message instead of an AttributeError on
            # the missing _states_weight / _gate_weight attributes.
            raise RuntimeError("HPUMambaMixer2._init_split_weights() must be called after checkpoint "
                               "weights are loaded and before forward().")

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # 1. Split in_proj into two GEMMs for TPC/MME pipelining.
        #    GEMM 1 (states: x,B,C,dt) is dispatched to the MME first;
        #    GEMM 2 (gate) is dispatched second.  The Gaudi runtime can
        #    overlap GEMM 2 on the MME with conv+SSM TPC work that
        #    depends only on GEMM 1.
        states_proj = F.linear(hidden_states, self._states_weight, self._states_bias)

        gate = F.linear(hidden_states, self._gate_weight, self._gate_bias)

        if mup_vector is not None:
            gate_size = self.tped_intermediate_size
            states_proj = states_proj * mup_vector[gate_size:]
            gate = gate * mup_vector[:gate_size]

        # 2. Prepare output buffer for conv + SSM
        ssm_output = torch.empty(
            [
                hidden_states.shape[0],
                (self.num_heads // self.tp_size) * self.head_dim,
            ],
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # 3. conv + SSM on TPC — overlaps with GEMM 2 on MME
        #    (result is written into ssm_output in place; the return value
        #    is intentionally not used here).
        self.conv_ssm_forward(states_proj, ssm_output)

        # 4. gated MLP (needs both gate from GEMM 2 and ssm_output)
        hidden_states_varlen = self.norm(ssm_output, gate)

        # 5. Final linear projection
        output, _ = self.out_proj(hidden_states_varlen)

        # NOTE(review): attn_metadata is dereferenced without a None check,
        # although conv_ssm_forward has a branch for attn_metadata being
        # None — confirm whether that case can reach this point.
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata.is_prompt:
            output = output.view(1, output.shape[0], output.shape[1])
        else:
            output = output.view(output.shape[0], 1, output.shape[1])

        return output

    # ------------------------------------------------------------------
    # Pre-clone weight slices as standalone contiguous tensors so that
    # F.linear sees them as independent parameters.  The Habana bridge
    # recognises F.linear and maps it to an optimised MME recipe that
    # does NOT require a separate TPC transpose of the weight, unlike
    # a raw torch.mm or a non-contiguous view.
    #
    # Must be called AFTER checkpoint weights have been loaded into
    # self.in_proj.weight and BEFORE PT_COMPILE_ONLY_MODE warmup,
    # because .clone() does not copy data in compile-only mode.
    # Called from apply_model_specific_patches() in hpu_model_runner.
    # ------------------------------------------------------------------
    def _init_split_weights(self):
        """Clone the gate and states slices of ``in_proj`` into own tensors.

        Sets ``_gate_weight`` / ``_states_weight`` (and the matching biases,
        or ``None`` when ``in_proj`` has no bias) and flips
        ``_split_weights_ready`` to ``True``.
        """
        gate_size = self.tped_intermediate_size
        w = self.in_proj.weight  # [total_out, hidden_size]
        b = self.in_proj.bias  # [total_out] or None

        # The first gate_size rows belong to the gate, the rest to x/B/C/dt.
        self._states_weight = w[gate_size:].clone()  # [states_out, hidden]
        self._gate_weight = w[:gate_size].clone()  # [gate_out, hidden]

        if b is not None:
            self._states_bias = b[gate_size:].clone()
            self._gate_bias = b[:gate_size].clone()
        else:
            self._states_bias = None
            self._gate_bias = None

        self._split_weights_ready = True

    def conv_ssm_forward(
        self,
        states_proj: torch.Tensor,
        output: torch.Tensor,
    ):
        """Run the causal convolution and the SSM step, filling ``output``.

        Args:
            states_proj: Result of the states GEMM; its last dimension is
                split into ``tped_conv_size`` (x, B, C) and ``tped_dt_size``
                (dt).
            output: Pre-allocated 2-D buffer.  The SSM kernels receive it as
                an ``out=`` view of shape ``(rows, -1, head_dim)``.

        Returns:
            ``None`` on the prefill and decode paths (the result lives in
            ``output``).  When the forward context carries no attention
            metadata (profile run), the hidden-state slice is returned and
            ``output`` is not written.
        """
        # states_proj contains [x, B, C, dt] (gate already split off).
        hidden_states_B_C, dt = torch.split(
            states_proj,
            [self.tped_conv_size, self.tped_dt_size],
            dim=-1,
        )

        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        assert self.cache_config is not None
        enable_prefix_caching = self.cache_config.enable_prefix_caching

        if attn_metadata is None:
            # profile run
            hidden_states_B_C = (hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)).contiguous()
            hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
            return hidden_states

        # conv_state = (..., dim, width-1) yet contiguous along 'dim'
        conv_state = self.kv_cache[0]
        ssm_state = self.kv_cache[1]

        # NOTE(review): self.cache_group_idx is not assigned in __init__;
        # presumably it is set on the layer from outside — confirm.
        load_indices_tensor = attn_metadata.load_indices_tensor[self.cache_group_idx]
        store_indices_tensor = attn_metadata.store_indices_tensor[self.cache_group_idx]
        if enable_prefix_caching and attn_metadata.is_prompt:
            blocks_caching_range = attn_metadata.blocks_caching_range[self.cache_group_idx]
            mamba_chunks_to_block_mapping = attn_metadata.mamba_chunks_to_block_mapping[self.cache_group_idx]
            seqlens_offsets_for_blocks = attn_metadata.seqlens_offsets_for_blocks
        else:
            blocks_caching_range = None
            mamba_chunks_to_block_mapping = None
            seqlens_offsets_for_blocks = None

        has_initial_states_p = attn_metadata.has_initial_states_p
        # is below sufficient to get chunk_size or does it need to passed via metadata
        assert self.model_config is not None
        chunk_size = self.model_config.get_mamba_chunk_size()
        query_start_loc_p = attn_metadata.query_start_loc_p
        last_chunk_indices_p = attn_metadata.last_chunk_indices_p
        padding_mask_flat = attn_metadata.padding_mask_flat

        if attn_metadata.is_prompt:
            # Process prefill requests
            assert padding_mask_flat is not None
            # Mask first, then build the transposed view: the multiplication
            # is out-of-place, so a view taken beforehand would still expose
            # the unmasked values to the convolution.
            hidden_states_B_C = hidden_states_B_C * padding_mask_flat
            dt = dt * padding_mask_flat
            x = hidden_states_B_C.transpose(0, 1)

            hidden_states_B_C = granite_causal_conv1d_fn(
                x,
                self.conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                enable_prefix_caching=enable_prefix_caching,
                load_cache_indices=load_indices_tensor,
                store_cache_indices=store_indices_tensor,
                blocks_caching_range=blocks_caching_range,
                seqlens_offsets_for_blocks=seqlens_offsets_for_blocks,
                metadata=attn_metadata,
                query_start_loc=query_start_loc_p,
                is_prompt=True,
            ).transpose(0, 1)

            hidden_states_B_C = hidden_states_B_C * padding_mask_flat
            hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(hidden_states_B_C)

            # 3. State Space Model sequence transformation
            initial_states = None
            if attn_metadata.prep_initial_states:
                initial_states = ssm_state[load_indices_tensor]

            # NOTE: final output is an in-place update of out tensor
            varlen_states = hpu_mamba_chunk_scan_combined_varlen(
                hidden_states_p.view(hidden_states_p.shape[0], self.num_heads // self.tp_size, self.head_dim),
                dt,
                self.A,
                B_p.view(B_p.shape[0], self.n_groups // self.tp_size, -1),
                C_p.view(C_p.shape[0], self.n_groups // self.tp_size, -1),
                chunk_size=chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
                cu_seqlens=query_start_loc_p,
                last_chunk_indices=last_chunk_indices_p,
                initial_states=initial_states,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
                out=output.view(output.shape[0], -1, self.head_dim),
                state_dtype=ssm_state.dtype,
                padding_mask=padding_mask_flat,
            )
            # In-place on purpose: the caller reads its own reference to this
            # buffer, so rebinding the local name would never reach it.
            output.mul_(padding_mask_flat.view(output.shape[0], 1))

            if enable_prefix_caching:
                ssm_state[mamba_chunks_to_block_mapping] = varlen_states
            else:
                ssm_state[store_indices_tensor] = varlen_states[last_chunk_indices_p]
        else:
            # Process decode requests
            # 2. Convolution sequence transformation
            hidden_states_B_C = granite_causal_conv1d_update(
                hidden_states_B_C,
                conv_state,
                self.conv_weights,
                self.conv1d.bias,
                self.activation,
                load_cache_indices=load_indices_tensor,
                store_cache_indices=store_indices_tensor,
                query_start_loc=query_start_loc_p,
            )

            hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(hidden_states_B_C)

            # 3. State Space Model sequence transformation
            n_groups = self.n_groups // self.tp_size
            A_d = self.A.to(dtype=torch.float32)  # (nheads,) — keep compact, no expand
            # Broadcast the per-head values across head_dim (expand, no copy).
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D_d = self.D[:, None, ...].expand(-1, self.head_dim)
            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
            hidden_states_d = hidden_states_d.view(-1, self.num_heads // self.tp_size, self.head_dim)

            # - the hidden is reshaped into (bs, num_heads, head_dim)
            # - mamba_cache_params.ssm_state's slots will be selected
            #   using state_indices_tensor
            # NOTE: final output is an in-place update of out tensor
            hpu_selective_state_update = get_selective_state_update_impl()
            hpu_selective_state_update(
                ssm_state,
                hidden_states_d,
                dt,
                A_d,
                B_d,
                C_d,
                D_d,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
                state_batch_indices=load_indices_tensor,
                dst_state_batch_indices=store_indices_tensor,
                out=output.view(output.shape[0], -1, self.head_dim),
            )

A instance-attribute

A = Parameter(
    empty(divide(num_heads, tp_size), dtype=float32)
)

D instance-attribute

D = Parameter(ones(num_heads // tp_size))

_split_weights_ready instance-attribute

_split_weights_ready = False

activation instance-attribute

activation = activation

cache_config instance-attribute

cache_config = cache_config

conv1d instance-attribute

conv1d = MergedColumnParallelLinear(
    input_size=conv_kernel_size,
    output_sizes=[
        intermediate_size,
        groups_ssm_state_size,
        groups_ssm_state_size,
    ],
    bias=use_conv_bias,
    quant_config=None,
    prefix=f"{prefix}.conv1d",
)

conv_dim instance-attribute

conv_dim = intermediate_size + 2 * groups_ssm_state_size

conv_kernel_size instance-attribute

conv_kernel_size = conv_kernel_size

dt_bias instance-attribute

dt_bias = Parameter(ones(num_heads // tp_size))

groups_ssm_state_size instance-attribute

groups_ssm_state_size = n_groups * ssm_state_size

head_dim instance-attribute

head_dim = head_dim

in_proj instance-attribute

in_proj = MergedColumnParallelLinear(
    input_size=hidden_size,
    output_sizes=[
        intermediate_size,
        intermediate_size,
        groups_ssm_state_size,
        groups_ssm_state_size,
        num_heads,
    ],
    bias=use_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.in_proj",
)

intermediate_size instance-attribute

intermediate_size = intermediate_size

kv_cache instance-attribute

kv_cache = (tensor([]), tensor([]))

model_config instance-attribute

model_config = model_config

n_groups instance-attribute

n_groups = n_groups

norm instance-attribute

norm = Mixer2RMSNormGated(
    intermediate_size,
    n_groups,
    use_rms_norm,
    eps=rms_norm_eps,
)

num_heads instance-attribute

num_heads = num_heads

num_spec instance-attribute

num_spec = num_speculative_tokens

out_proj instance-attribute

out_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=use_bias,
    input_is_parallel=True,
    quant_config=quant_config,
    prefix=f"{prefix}.out_proj",
)

prefix instance-attribute

prefix = prefix

split_hidden_states_B_C_fn instance-attribute

split_hidden_states_B_C_fn = lambda hidden_states_B_C: (
    split(
        hidden_states_B_C,
        [
            tped_intermediate_size,
            groups_ssm_state_size // tp_size,
            groups_ssm_state_size // tp_size,
        ],
        dim=-1,
    )
)

ssm_state_size instance-attribute

ssm_state_size = ssm_state_size

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

tped_conv_size instance-attribute

tped_conv_size = conv_dim // tp_size

tped_dt_size instance-attribute

tped_dt_size = num_heads // tp_size

tped_intermediate_size instance-attribute

tped_intermediate_size = intermediate_size // tp_size

use_rms_norm instance-attribute

use_rms_norm = use_rms_norm

__init__

__init__(
    hidden_size: int,
    ssm_state_size: int,
    conv_kernel_size: int,
    intermediate_size: int,
    use_conv_bias: bool,
    use_bias: bool,
    n_groups: int = 1,
    num_heads: int = 128,
    head_dim: int = 64,
    rms_norm_eps: float = 1e-05,
    activation: str = "silu",
    use_rms_norm: bool = True,
    model_config: ModelConfig | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
)
Source code in vllm_gaudi/ops/hpu_mamba_mixer2.py
def __init__(
    self,
    hidden_size: int,
    ssm_state_size: int,
    conv_kernel_size: int,
    intermediate_size: int,
    use_conv_bias: bool,
    use_bias: bool,
    n_groups: int = 1,
    num_heads: int = 128,
    head_dim: int = 64,
    rms_norm_eps: float = 1e-5,
    activation: str = "silu",
    use_rms_norm: bool = True,
    model_config: ModelConfig | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
):
    """Build the projections, SSM parameters and bookkeeping of one layer.

    Args:
        hidden_size: Input width of ``in_proj`` and output width of
            ``out_proj``.
        ssm_state_size: Multiplied by ``n_groups`` to size the B and C
            slices (``groups_ssm_state_size``).
        conv_kernel_size: ``input_size`` of the ``conv1d`` parallel-linear
            layer that holds the convolution weights.
        intermediate_size: Width of the gate and hidden-state slices
            before tensor-parallel sharding.
        use_conv_bias: Whether ``conv1d`` is created with a bias.
        use_bias: Whether ``in_proj`` and ``out_proj`` are created with a
            bias.
        n_groups: Number of B/C groups; must be divisible by the
            tensor-parallel world size.
        num_heads: Number of heads; must be divisible by the
            tensor-parallel world size.
        head_dim: Per-head width used when reshaping hidden states.
        rms_norm_eps: ``eps`` forwarded to ``Mixer2RMSNormGated``.
        activation: Activation name forwarded to the convolution kernels.
        use_rms_norm: Forwarded to ``Mixer2RMSNormGated``.
        model_config: Read in ``conv_ssm_forward`` to obtain the mamba
            chunk size.
        cache_config: Read in ``conv_ssm_forward`` for
            ``enable_prefix_caching``.
        quant_config: Passed to ``in_proj`` and ``out_proj``; ``conv1d``
            is always built with ``quant_config=None``.
        prefix: Layer name; used for sub-layer prefixes and as the key in
            ``compilation_config.static_forward_context``.

    Raises:
        AssertionError: If the tensor-parallel world size does not divide
            ``num_heads`` or ``n_groups``.
        ValueError: If ``prefix`` is already present in
            ``static_forward_context``.
    """
    # Starts the MRO lookup *after* MambaMixer2, so MambaMixer2.__init__
    # itself is not run; every attribute is built below instead.
    super(MambaMixer2, self).__init__()

    self.tp_size = get_tensor_model_parallel_world_size()

    assert num_heads % self.tp_size == 0, ("Tensor parallel world size must divide num heads.")

    # Unlike the "or n_groups == 1" relaxation, this class requires plain
    # divisibility (n_groups == 1 with tp_size > 1 is rejected).
    assert n_groups % self.tp_size == 0, ("Tensor parallel world size must divide num_groups.")

    self.ssm_state_size = ssm_state_size
    self.conv_kernel_size = conv_kernel_size
    self.activation = activation

    self.intermediate_size = intermediate_size
    self.head_dim = head_dim
    self.num_heads = num_heads
    self.n_groups = n_groups

    self.num_spec = get_current_vllm_config().num_speculative_tokens

    self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
    self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size

    self.conv1d = MergedColumnParallelLinear(
        input_size=conv_kernel_size,
        output_sizes=[
            intermediate_size,
            self.groups_ssm_state_size,
            self.groups_ssm_state_size,
        ],
        bias=use_conv_bias,
        quant_config=None,
        prefix=f"{prefix}.conv1d",
    )

    # Output slices, in order: gate, x, B, C, dt.  forward() and
    # _init_split_weights() rely on the gate being the first slice.
    self.in_proj = MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[
            intermediate_size,
            intermediate_size,
            self.groups_ssm_state_size,
            self.groups_ssm_state_size,
            self.num_heads,
        ],
        bias=use_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.in_proj",
    )

    # unsqueeze to fit conv1d weights shape into the linear weights shape.
    # Can't do this in `weight_loader` since it already exists in
    # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
    # and `set_weight_attrs` doesn't allow to override it
    self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
    # 2-D view (dim 1 dropped again) that is handed to the conv kernels.
    conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
    self.register_buffer("conv_weights", conv_weights, persistent=False)

    # - these are TPed by heads to reduce the size of the
    #   temporal shape
    self.A = nn.Parameter(torch.empty(
        divide(num_heads, self.tp_size),
        dtype=torch.float32,
    ))
    self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
    self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
    self.use_rms_norm = use_rms_norm

    set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
    # A is stored as -exp(checkpoint value), computed in float32 at load time.
    a_weight_loader = composed_weight_loader(sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
    set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
    set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

    self.out_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=use_bias,
        input_is_parallel=True,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )

    self.norm = Mixer2RMSNormGated(intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps)

    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
    # The tuple is (conv_state, ssm_state)
    self.kv_cache = (torch.tensor([]), torch.tensor([]))

    self.model_config = model_config
    self.cache_config = cache_config
    self.prefix = prefix

    # Pre-compute sizes for forward pass
    self.tped_intermediate_size = self.intermediate_size // self.tp_size
    self.tped_conv_size = self.conv_dim // self.tp_size
    self.tped_dt_size = self.num_heads // self.tp_size

    # Set to True by _init_split_weights().
    self._split_weights_ready = False

    # - get hidden_states, B and C after depthwise convolution.
    self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
        hidden_states_B_C,
        [
            self.tped_intermediate_size,
            self.groups_ssm_state_size // self.tp_size,
            self.groups_ssm_state_size // self.tp_size,
        ],
        dim=-1,
    )

_init_split_weights

_init_split_weights()
Source code in vllm_gaudi/ops/hpu_mamba_mixer2.py
def _init_split_weights(self):
    """Clone the gate and states slices of the fused ``in_proj`` parameters.

    The first ``tped_intermediate_size`` rows of ``in_proj.weight`` (and
    entries of ``in_proj.bias``, when present) are the gate; everything after
    them is the states part.  Each slice is cloned into its own tensor and
    stored on ``self``; ``_split_weights_ready`` is set to ``True`` at the end.
    """
    split_at = self.tped_intermediate_size
    fused_weight = self.in_proj.weight  # [total_out, hidden_size]
    fused_bias = self.in_proj.bias  # [total_out] or None

    # Weights: states rows follow the gate rows.
    self._states_weight = fused_weight[split_at:].clone()  # [states_out, hidden]
    self._gate_weight = fused_weight[:split_at].clone()  # [gate_out, hidden]

    # Biases mirror the weight split, or stay None when in_proj has no bias.
    bias_present = fused_bias is not None
    self._states_bias = fused_bias[split_at:].clone() if bias_present else None
    self._gate_bias = fused_bias[:split_at].clone() if bias_present else None

    self._split_weights_ready = True

conv_ssm_forward

conv_ssm_forward(states_proj: Tensor, output: Tensor)
Source code in vllm_gaudi/ops/hpu_mamba_mixer2.py
def conv_ssm_forward(
    self,
    states_proj: torch.Tensor,
    output: torch.Tensor,
):
    """Run the causal convolution and the SSM step, filling ``output``.

    Args:
        states_proj: Result of the states GEMM; its last dimension is
            split into ``tped_conv_size`` (x, B, C) and ``tped_dt_size``
            (dt).
        output: Pre-allocated 2-D buffer.  The SSM kernels receive it as
            an ``out=`` view of shape ``(rows, -1, head_dim)``.

    Returns:
        ``None`` on the prefill and decode paths (the result lives in
        ``output``).  When the forward context carries no attention
        metadata (profile run), the hidden-state slice is returned and
        ``output`` is not written.
    """
    # states_proj contains [x, B, C, dt] (gate already split off).
    hidden_states_B_C, dt = torch.split(
        states_proj,
        [self.tped_conv_size, self.tped_dt_size],
        dim=-1,
    )

    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata

    assert self.cache_config is not None
    enable_prefix_caching = self.cache_config.enable_prefix_caching

    if attn_metadata is None:
        # profile run
        hidden_states_B_C = (hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)).contiguous()
        hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
        return hidden_states

    # conv_state = (..., dim, width-1) yet contiguous along 'dim'
    conv_state = self.kv_cache[0]
    ssm_state = self.kv_cache[1]

    # NOTE(review): self.cache_group_idx is not assigned in __init__;
    # presumably it is set on the layer from outside — confirm.
    load_indices_tensor = attn_metadata.load_indices_tensor[self.cache_group_idx]
    store_indices_tensor = attn_metadata.store_indices_tensor[self.cache_group_idx]
    if enable_prefix_caching and attn_metadata.is_prompt:
        blocks_caching_range = attn_metadata.blocks_caching_range[self.cache_group_idx]
        mamba_chunks_to_block_mapping = attn_metadata.mamba_chunks_to_block_mapping[self.cache_group_idx]
        seqlens_offsets_for_blocks = attn_metadata.seqlens_offsets_for_blocks
    else:
        blocks_caching_range = None
        mamba_chunks_to_block_mapping = None
        seqlens_offsets_for_blocks = None

    has_initial_states_p = attn_metadata.has_initial_states_p
    # is below sufficient to get chunk_size or does it need to passed via metadata
    assert self.model_config is not None
    chunk_size = self.model_config.get_mamba_chunk_size()
    query_start_loc_p = attn_metadata.query_start_loc_p
    last_chunk_indices_p = attn_metadata.last_chunk_indices_p
    padding_mask_flat = attn_metadata.padding_mask_flat

    if attn_metadata.is_prompt:
        # Process prefill requests
        assert padding_mask_flat is not None
        # Mask first, then build the transposed view: the multiplication
        # is out-of-place, so a view taken beforehand would still expose
        # the unmasked values to the convolution.
        hidden_states_B_C = hidden_states_B_C * padding_mask_flat
        dt = dt * padding_mask_flat
        x = hidden_states_B_C.transpose(0, 1)

        hidden_states_B_C = granite_causal_conv1d_fn(
            x,
            self.conv_weights,
            self.conv1d.bias,
            activation=self.activation,
            conv_states=conv_state,
            has_initial_state=has_initial_states_p,
            enable_prefix_caching=enable_prefix_caching,
            load_cache_indices=load_indices_tensor,
            store_cache_indices=store_indices_tensor,
            blocks_caching_range=blocks_caching_range,
            seqlens_offsets_for_blocks=seqlens_offsets_for_blocks,
            metadata=attn_metadata,
            query_start_loc=query_start_loc_p,
            is_prompt=True,
        ).transpose(0, 1)

        hidden_states_B_C = hidden_states_B_C * padding_mask_flat
        hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(hidden_states_B_C)

        # 3. State Space Model sequence transformation
        initial_states = None
        if attn_metadata.prep_initial_states:
            initial_states = ssm_state[load_indices_tensor]

        # NOTE: final output is an in-place update of out tensor
        varlen_states = hpu_mamba_chunk_scan_combined_varlen(
            hidden_states_p.view(hidden_states_p.shape[0], self.num_heads // self.tp_size, self.head_dim),
            dt,
            self.A,
            B_p.view(B_p.shape[0], self.n_groups // self.tp_size, -1),
            C_p.view(C_p.shape[0], self.n_groups // self.tp_size, -1),
            chunk_size=chunk_size,
            D=self.D,
            z=None,
            dt_bias=self.dt_bias,
            cu_seqlens=query_start_loc_p,
            last_chunk_indices=last_chunk_indices_p,
            initial_states=initial_states,
            dt_softplus=True,
            dt_limit=(0.0, float("inf")),
            out=output.view(output.shape[0], -1, self.head_dim),
            state_dtype=ssm_state.dtype,
            padding_mask=padding_mask_flat,
        )
        # In-place on purpose: the caller reads its own reference to this
        # buffer, so rebinding the local name would never reach it.
        output.mul_(padding_mask_flat.view(output.shape[0], 1))

        if enable_prefix_caching:
            ssm_state[mamba_chunks_to_block_mapping] = varlen_states
        else:
            ssm_state[store_indices_tensor] = varlen_states[last_chunk_indices_p]
    else:
        # Process decode requests
        # 2. Convolution sequence transformation
        hidden_states_B_C = granite_causal_conv1d_update(
            hidden_states_B_C,
            conv_state,
            self.conv_weights,
            self.conv1d.bias,
            self.activation,
            load_cache_indices=load_indices_tensor,
            store_cache_indices=store_indices_tensor,
            query_start_loc=query_start_loc_p,
        )

        hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(hidden_states_B_C)

        # 3. State Space Model sequence transformation
        n_groups = self.n_groups // self.tp_size
        A_d = self.A.to(dtype=torch.float32)  # (nheads,) — keep compact, no expand
        # Broadcast the per-head values across head_dim (expand, no copy).
        dt = dt[:, :, None].expand(-1, -1, self.head_dim)
        dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
        D_d = self.D[:, None, ...].expand(-1, self.head_dim)
        B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
        C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
        hidden_states_d = hidden_states_d.view(-1, self.num_heads // self.tp_size, self.head_dim)

        # - the hidden is reshaped into (bs, num_heads, head_dim)
        # - mamba_cache_params.ssm_state's slots will be selected
        #   using state_indices_tensor
        # NOTE: final output is an in-place update of out tensor
        hpu_selective_state_update = get_selective_state_update_impl()
        hpu_selective_state_update(
            ssm_state,
            hidden_states_d,
            dt,
            A_d,
            B_d,
            C_d,
            D_d,
            z=None,
            dt_bias=dt_bias,
            dt_softplus=True,
            state_batch_indices=load_indices_tensor,
            dst_state_batch_indices=store_indices_tensor,
            out=output.view(output.shape[0], -1, self.head_dim),
        )

forward

forward(
    hidden_states: Tensor, mup_vector: Tensor | None = None
)
Source code in vllm_gaudi/ops/hpu_mamba_mixer2.py
def forward(
    self,
    hidden_states: torch.Tensor,
    mup_vector: torch.Tensor | None = None,
):
    """Run the mixer: split input projection, conv + SSM, gated norm, out_proj.

    Args:
        hidden_states: Input activations. They are flattened to 2D
            ``(-1, hidden_states.shape[-1])`` before any projection.
        mup_vector: Optional per-channel multiplier for the projected
            activations. Its last dimension is split at
            ``self.tped_intermediate_size``: the leading part scales the
            gate, the remainder scales the states projection.

    Returns:
        The ``out_proj`` result viewed as ``(1, tokens, features)`` when the
        forward context's attention metadata reports a prompt step, and as
        ``(tokens, 1, features)`` otherwise.
    """
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    # 1. Split in_proj into two GEMMs for TPC/MME pipelining.
    #    GEMM 1 (states: x,B,C,dt) is dispatched to the MME first;
    #    GEMM 2 (gate) is dispatched second.  The Gaudi runtime can
    #    overlap GEMM 2 on the MME with conv+SSM TPC work that
    #    depends only on GEMM 1.
    states_proj = F.linear(hidden_states, self._states_weight, self._states_bias)

    gate = F.linear(hidden_states, self._gate_weight, self._gate_bias)

    if mup_vector is not None:
        gate_size = self.tped_intermediate_size
        # Slice along the LAST dimension. For a 1-D vector this is the same
        # as plain slicing; for a vector carrying leading broadcast dims
        # (e.g. shape (1, N)) plain slicing would cut the wrong axis.
        states_proj = states_proj * mup_vector[..., gate_size:]
        gate = gate * mup_vector[..., :gate_size]

    # 2. Prepare output buffer for conv + SSM
    ssm_output = torch.empty(
        [
            hidden_states.shape[0],
            (self.num_heads // self.tp_size) * self.head_dim,
        ],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    # 3. conv + SSM on TPC — overlaps with GEMM 2 on MME
    #    (ssm_output is passed in as the destination buffer)
    self.conv_ssm_forward(states_proj, ssm_output)

    # 4. gated MLP (needs both gate from GEMM 2 and ssm_output)
    hidden_states_varlen = self.norm(ssm_output, gate)

    # 5. Final linear projection (second return value is unused here)
    output, _ = self.out_proj(hidden_states_varlen)

    # Restore a 3D layout: prompt steps put all tokens in one batch row,
    # other steps use one token per batch row.
    if get_forward_context().attn_metadata.is_prompt:
        output = output.view(1, output.shape[0], output.shape[1])
    else:
        output = output.view(output.shape[0], 1, output.shape[1])

    return output

HPUMixer2RMSNormGated

Bases: Mixer2RMSNormGated

Source code in vllm_gaudi/ops/hpu_mamba_mixer2.py
@Mixer2RMSNormGated.register_oot
class HPUMixer2RMSNormGated(Mixer2RMSNormGated):
    """Out-of-tree registration of the gated, grouped RMSNorm used by the mixer."""

    def __init__(
        self,
        full_hidden_size: int,
        full_n_groups: int,
        use_rms_norm: bool = True,
        eps: float = 1e-6,
    ):
        """Record tensor-parallel geometry and optionally create the norm weight.

        Args:
            full_hidden_size: Hidden size before tensor-parallel sharding.
            full_n_groups: Number of normalization groups over the full
                hidden size.
            use_rms_norm: When False, no weight is created and the forward
                pass only applies the SiLU gate.
            eps: Value added to the variance before the reciprocal sqrt.
        """
        CustomOp.__init__(self)

        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()
        channels_per_group = full_hidden_size // full_n_groups

        self.tp_size = world_size
        self.tp_rank = rank
        self.full_hidden_size = full_hidden_size
        self.group_size = channels_per_group
        self.per_rank_hidden_size = full_hidden_size // world_size
        self.n_groups = full_hidden_size // channels_per_group

        self.variance_epsilon = eps
        self.use_rms_norm = use_rms_norm

        if not use_rms_norm:
            # No RMSNorm applied: keep "weight" registered as None so the
            # unused parameter does not cause a checkpoint mismatch.
            self.register_parameter("weight", None)
        else:
            # One scale per locally-held channel, loaded through the
            # sharded weight loader built with argument 0.
            scale = nn.Parameter(torch.ones(self.per_rank_hidden_size))
            set_weight_attrs(scale, {"weight_loader": sharded_weight_loader(0)})
            self.weight = scale
        assert self.full_hidden_size % self.tp_size == 0, ("Tensor parallel world size must divide hidden size.")

    def forward_oot(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Apply ``x * silu(gate)`` followed by (optionally grouped) RMSNorm.

        Tensor-parallel handling depends on the group layout:
          * ``n_groups == 1``: the reduction dim is sharded, so each rank
            sums its local squares and the sums are all-reduced.
          * ``tp_size`` divides ``n_groups``: every group is local to a
            rank, so no collective is needed.
          * otherwise: the input is all-gathered, the norm is computed
            redundantly, and the rank's own slice is taken back out.

        Returns:
            The gated input cast to ``x``'s original dtype when RMSNorm is
            disabled; otherwise ``self.weight`` times the normalized
            activations (which are cast to the original dtype first).
        """
        out_dtype = x.dtype
        # The gate is promoted to float32 before the SiLU.
        gated = x * nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return gated.to(out_dtype)

        eps = self.variance_epsilon

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Local sum of squares -> global sum -> mean over all ranks.
                sq_sum = tensor_model_parallel_all_reduce(gated.pow(2).sum(dim=-1, keepdim=True))
                mean_sq = sq_sum / (self.tp_size * gated.shape[-1])
            else:
                mean_sq = gated.pow(2).mean(-1, keepdim=True)
            normed = gated * torch.rsqrt(mean_sq + eps)
            return self.weight * normed.to(out_dtype)

        # Groups straddle ranks when tp_size does not divide n_groups.
        needs_gather: bool = self.n_groups % self.tp_size != 0
        if needs_gather:
            gated = tensor_model_parallel_all_gather(gated, -1)

        *lead, width = gated.shape
        grouped = gated.view(*lead, width // self.group_size, self.group_size)
        mean_sq = grouped.pow(2).mean(-1, keepdim=True)
        normed = (grouped * torch.rsqrt(mean_sq + eps)).view(*lead, width)

        if needs_gather:
            # Keep only the channels owned by this rank.
            lo = self.per_rank_hidden_size * self.tp_rank
            normed = normed[..., lo:lo + self.per_rank_hidden_size]

        return self.weight * normed.to(out_dtype)

full_hidden_size instance-attribute

full_hidden_size = full_hidden_size

group_size instance-attribute

group_size = full_hidden_size // full_n_groups

n_groups instance-attribute

n_groups = full_hidden_size // group_size

per_rank_hidden_size instance-attribute

per_rank_hidden_size = full_hidden_size // tp_size

tp_rank instance-attribute

tp_rank = get_tensor_model_parallel_rank()

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

use_rms_norm instance-attribute

use_rms_norm = use_rms_norm

variance_epsilon instance-attribute

variance_epsilon = eps

weight instance-attribute

weight = Parameter(ones(per_rank_hidden_size))

__init__

__init__(
    full_hidden_size: int,
    full_n_groups: int,
    use_rms_norm: bool = True,
    eps: float = 1e-06,
)
Source code in vllm_gaudi/ops/hpu_mamba_mixer2.py
def __init__(
    self,
    full_hidden_size: int,
    full_n_groups: int,
    use_rms_norm: bool = True,
    eps: float = 1e-6,
):
    """Record tensor-parallel geometry and optionally create the norm weight.

    Args:
        full_hidden_size: Hidden size before tensor-parallel sharding.
        full_n_groups: Number of normalization groups over the full
            hidden size.
        use_rms_norm: When False, no weight is created and the forward
            pass only applies the SiLU gate.
        eps: Value added to the variance before the reciprocal sqrt.
    """
    CustomOp.__init__(self)

    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    channels_per_group = full_hidden_size // full_n_groups

    self.tp_size = world_size
    self.tp_rank = rank
    self.full_hidden_size = full_hidden_size
    self.group_size = channels_per_group
    self.per_rank_hidden_size = full_hidden_size // world_size
    self.n_groups = full_hidden_size // channels_per_group

    self.variance_epsilon = eps
    self.use_rms_norm = use_rms_norm

    if not use_rms_norm:
        # No RMSNorm applied: keep "weight" registered as None so the
        # unused parameter does not cause a checkpoint mismatch.
        self.register_parameter("weight", None)
    else:
        # One scale per locally-held channel, loaded through the
        # sharded weight loader built with argument 0.
        scale = nn.Parameter(torch.ones(self.per_rank_hidden_size))
        set_weight_attrs(scale, {"weight_loader": sharded_weight_loader(0)})
        self.weight = scale
    assert self.full_hidden_size % self.tp_size == 0, ("Tensor parallel world size must divide hidden size.")

forward_oot

forward_oot(x: Tensor, gate: Tensor)
Source code in vllm_gaudi/ops/hpu_mamba_mixer2.py
def forward_oot(
    self,
    x: torch.Tensor,
    gate: torch.Tensor,
):
    """Apply ``x * silu(gate)`` followed by (optionally grouped) RMSNorm.

    Tensor-parallel handling depends on the group layout:
      * ``n_groups == 1``: the reduction dim is sharded, so each rank
        sums its local squares and the sums are all-reduced.
      * ``tp_size`` divides ``n_groups``: every group is local to a
        rank, so no collective is needed.
      * otherwise: the input is all-gathered, the norm is computed
        redundantly, and the rank's own slice is taken back out.

    Returns:
        The gated input cast to ``x``'s original dtype when RMSNorm is
        disabled; otherwise ``self.weight`` times the normalized
        activations (which are cast to the original dtype first).
    """
    out_dtype = x.dtype
    # The gate is promoted to float32 before the SiLU.
    gated = x * nn.functional.silu(gate.to(torch.float32))
    if not self.use_rms_norm:
        return gated.to(out_dtype)

    eps = self.variance_epsilon

    if self.n_groups == 1:
        if self.tp_size > 1:
            # Local sum of squares -> global sum -> mean over all ranks.
            sq_sum = tensor_model_parallel_all_reduce(gated.pow(2).sum(dim=-1, keepdim=True))
            mean_sq = sq_sum / (self.tp_size * gated.shape[-1])
        else:
            mean_sq = gated.pow(2).mean(-1, keepdim=True)
        normed = gated * torch.rsqrt(mean_sq + eps)
        return self.weight * normed.to(out_dtype)

    # Groups straddle ranks when tp_size does not divide n_groups.
    needs_gather: bool = self.n_groups % self.tp_size != 0
    if needs_gather:
        gated = tensor_model_parallel_all_gather(gated, -1)

    *lead, width = gated.shape
    grouped = gated.view(*lead, width // self.group_size, self.group_size)
    mean_sq = grouped.pow(2).mean(-1, keepdim=True)
    normed = (grouped * torch.rsqrt(mean_sq + eps)).view(*lead, width)

    if needs_gather:
        # Keep only the channels owned by this rank.
        lo = self.per_rank_hidden_size * self.tp_rank
        normed = normed[..., lo:lo + self.per_rank_hidden_size]

    return self.weight * normed.to(out_dtype)