Skip to content

vllm_gaudi.ops.ssd_combined

_mamba_chunk_scan_combined_fwd

_mamba_chunk_scan_combined_fwd(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    cu_seqlens=None,
    last_chunk_indices=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    state_dtype=None,
    padding_mask=None,
)
Source code in vllm_gaudi/ops/ssd_combined.py
def _mamba_chunk_scan_combined_fwd(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        out,
        D=None,
        z=None,
        dt_bias=None,
        initial_states=None,
        cu_seqlens=None,
        last_chunk_indices=None,
        dt_softplus=False,
        dt_limit=(0.0, float("inf")),
        state_dtype=None,
        padding_mask=None,
):
    assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
    seqlen, nheads, headdim = x.shape
    _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (seqlen, ngroups, dstate)
    assert dt.shape == (seqlen, nheads)
    assert A.shape == (nheads, )
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads, )
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if (x.stride(-1) != 1 and x.stride(0) != 1):  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if (z is not None and z.stride(-1) != 1 and z.stride(0) != 1):  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"

    if initial_states is not None:
        assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)

    # This function executes 5 sub-functions for computing mamba
    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
    #   which has a minimal implementation to understand the below operations
    # - as explained by the blog, mamba is a special case of causal attention
    # - the idea is to chunk the attention matrix and compute each
    #   submatrix separately using different optimizations.
    # - see the blog and paper for a visualization of the submatrices
    #   which we refer to in the comments below

    # 1. Compute chunked cumsum of A * dt
    # - here dt may go through a softplus activation
    dA_cumsum, dt = new_chunk_cumsum(
        dt,
        A,
        chunk_size,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        padding_mask=padding_mask,
    )

    nchunks = seqlen // chunk_size
    nheads_ngroups_ratio = nheads // ngroups
    dt_t = dt.transpose(0, 1)  # (nchunks, nheads, chunk_size)
    dA_cumsum_t = dA_cumsum.transpose(0, 1)  # (nchunks, nheads, chunk_size)
    x_chunked = x.view(nchunks, chunk_size, nheads, headdim)
    B_expanded = B.view(nchunks, chunk_size, ngroups, 1, dstate).expand(-1, -1, -1, nheads_ngroups_ratio,
                                                                        -1).reshape(nchunks, chunk_size, nheads, dstate)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    states = new_chunk_state(B_expanded, x_chunked, dt_t, dA_cumsum_t, states_in_fp32=True)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    # - for handling chunked prefill, this requires initial_states
    states = new_ssd_state_passing(
        states.flatten(-2),
        dA_cumsum,  # (nheads, nchunks, chunk_size)
        initial_states=initial_states.flatten(-2)
        if initial_states is not None else None,  # (batch, nheads, headdim*dstate)
        out_dtype=state_dtype if state_dtype is not None else C.dtype,
    )
    states = states.view(states.shape[0], states.shape[1], -1, dstate)

    # 4. Compute batched matrix multiply for C_j^T B_i terms
    CB = new_ssd_bmm(C, B, chunk_size, causal=True, output_dtype=torch.float32)

    # 5. Scan and compute the diagonal blocks, taking into
    #    account past causal states.
    # - if initial states are provided, then states information will be
    #   augmented with initial_states.
    # - to do this properly, we need to account for example changes in
    #   the continuous batch, therefore we introduce pseudo chunks, which is
    #   a chunk that is split up each time an example changes.
    new_chunk_scan(
        CB,
        x_chunked,
        dt_t,
        dA_cumsum_t,
        C,
        states,
        out,  # in-place update
        D=D,
        z=z,
        initial_states=initial_states,
    )

    return states

hpu_mamba_chunk_scan_combined_varlen

hpu_mamba_chunk_scan_combined_varlen(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    cu_seqlens,
    last_chunk_indices,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    state_dtype=None,
    padding_mask=None,
)
Argument

x: (seqlen, nheads, headdim) dt: (seqlen, nheads) A: (nheads) B: (seqlen, ngroups, dstate) C: (seqlen, ngroups, dstate) chunk_size: int cu_seqlens: (batch + 1,) last_chunk_indices: (batch,) out: (seqlen, nheads, headdim) preallocated output tensor D: (nheads, headdim) or (nheads,) z: (seqlen, nheads, headdim) dt_bias: (nheads,) initial_states: (batch, nheads, headdim, dstate) dt_softplus: Whether to apply softplus to dt out: (seqlen, nheads, headdim) preallocated output tensor state_dtype: The data type of the ssm state

Return: varlen_states: (batch, nheads, headdim, dstate)

Source code in vllm_gaudi/ops/ssd_combined.py
def hpu_mamba_chunk_scan_combined_varlen(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        cu_seqlens,
        last_chunk_indices,
        out,
        D=None,
        z=None,
        dt_bias=None,
        initial_states=None,
        dt_softplus=False,
        dt_limit=(0.0, float("inf")),
        state_dtype=None,
        padding_mask=None,
):
    """
    Argument:
        x: (seqlen, nheads, headdim)
        dt: (seqlen, nheads)
        A: (nheads)
        B: (seqlen, ngroups, dstate)
        C: (seqlen, ngroups, dstate)
        chunk_size: int
        cu_seqlens: (batch + 1,)
        last_chunk_indices: (batch,)
        out: (seqlen, nheads, headdim) preallocated output tensor
        D: (nheads, headdim) or (nheads,)
        z: (seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        dt_softplus: Whether to apply softplus to dt
        out: (seqlen, nheads, headdim) preallocated output tensor
        state_dtype: The data type of the ssm state
    Return:
        varlen_states: (batch, nheads, headdim, dstate)
    """

    assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"

    varlen_states = _mamba_chunk_scan_combined_fwd(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        out,
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        cu_seqlens=cu_seqlens,
        last_chunk_indices=last_chunk_indices,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        state_dtype=state_dtype,
        padding_mask=padding_mask,
    )

    return varlen_states

is_int_pow_2

is_int_pow_2(n)
Source code in vllm_gaudi/ops/ssd_combined.py
def is_int_pow_2(n):
    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0