Skip to content

vllm_gaudi.ops.ops_selector

Selector module that switches between the PyTorch and Triton implementations of Mamba operations based on an environment variable.

Set VLLM_MAMBA_USE_PYTORCH=1 to use PyTorch implementations. Default (unset or 0) uses optimized Triton implementations.

_USE_PYTORCH module-attribute

_USE_PYTORCH = get('VLLM_MAMBA_USE_PYTORCH', '0') == '1'

_USE_SELECTIVE_STATE_UPDATE_REF module-attribute

_USE_SELECTIVE_STATE_UPDATE_REF = (
    get("VLLM_MAMBA_USE_SELECTIVE_STATE_UPDATE_REF_PT", "1")
    == "1"
)

_use_pytorch_runtime

_use_pytorch_runtime()

Check at runtime whether to use PyTorch implementation.
This allows torch.compile to respect the environment variable.

Source code in vllm_gaudi/ops/ops_selector.py
def _use_pytorch_runtime():
    """Report whether the PyTorch implementation is requested right now.

    The ``VLLM_MAMBA_USE_PYTORCH`` environment variable is read on every
    call (not cached), which allows torch.compile to respect the
    environment variable.

    Returns:
        True only when the variable is set to exactly ``"1"``; False when
        it is unset or holds any other value.
    """
    # An unset variable falls back to "0", i.e. disabled.
    env_value = os.environ.get("VLLM_MAMBA_USE_PYTORCH", "0")
    return env_value == "1"

_wrap_selective_state_update_ref

_wrap_selective_state_update_ref(
    selective_state_update_ref_fn,
)

Wrapper to adapt PyTorch selective_state_update_ref to match Triton API.

Source code in vllm_gaudi/ops/ops_selector.py
def _wrap_selective_state_update_ref(selective_state_update_ref_fn):
    """Adapt the PyTorch ``selective_state_update_ref`` to the Triton-style API.

    The reference function does not accept ``state_batch_indices``,
    ``dst_state_batch_indices`` or ``out``. The returned callable accepts
    them and emulates their effect around a call to the reference function.

    Args:
        selective_state_update_ref_fn: Callable invoked as
            ``fn(state, x, dt, A, B, C, D=..., z=..., dt_bias=...,
            dt_softplus=...)`` that returns the output tensor.

    Returns:
        A callable taking the reference arguments plus
        ``state_batch_indices``, ``dst_state_batch_indices`` and ``out``.
        It returns ``out`` (filled with the result) when ``out`` is given,
        otherwise the result of the reference function.
    """

    def wrapped(state,
                x,
                dt,
                A,
                B,
                C,
                D=None,
                z=None,
                dt_bias=None,
                dt_softplus=False,
                state_batch_indices=None,
                dst_state_batch_indices=None,
                out=None):

        def run_reference(target_state):
            # Single call site for the reference kernel; only the state
            # tensor differs between the indexed and non-indexed paths.
            return selective_state_update_ref_fn(target_state,
                                                 x,
                                                 dt,
                                                 A,
                                                 B,
                                                 C,
                                                 D=D,
                                                 z=z,
                                                 dt_bias=dt_bias,
                                                 dt_softplus=dt_softplus)

        no_indices = state_batch_indices is None and dst_state_batch_indices is None
        if no_indices:
            # Simple path: the reference function works on `state` directly.
            result = run_reference(state)
        else:
            # Emulate Triton's slot selection by hand:
            #   1. gather the state rows named by the read indices,
            #   2. run the update on the gathered rows,
            #   3. write those rows back at the destination indices.
            if state_batch_indices is not None:
                read_slots = state_batch_indices
            else:
                # Only destination indices were given: read rows 0..batch-1.
                read_slots = torch.arange(x.shape[0], device=x.device)
            write_slots = read_slots if dst_state_batch_indices is None else dst_state_batch_indices

            # index_select (instead of fancy indexing) dispatches a dedicated
            # gather kernel rather than the gather_nd TPC kernel.
            gathered_state = torch.index_select(state, 0, read_slots.long())

            # NOTE(review): the write-back below only has an effect if the
            # reference function updates the state it receives in place -
            # confirm against pytorch_implementation.
            result = run_reference(gathered_state)

            # index_copy_ (instead of fancy indexing) dispatches a dedicated
            # copy kernel rather than the scatter_nd TPC kernel.
            state.index_copy_(0, write_slots.long(), gathered_state)

        # Match Triton's in-place behaviour when a destination buffer is given.
        if out is None:
            return result
        out.copy_(result)
        return out

    return wrapped

get_selective_state_update_impl

get_selective_state_update_impl()

Returns the selective state update implementation.

PyTorch version signature

`selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False)`. Returns: output tensor.

Source code in vllm_gaudi/ops/ops_selector.py
def get_selective_state_update_impl():
    """
    Returns the selective state update implementation.

    The returned callable forwards every call to the PyTorch reference
    implementation, adapted by ``_wrap_selective_state_update_ref`` so that
    it also accepts ``state_batch_indices``, ``dst_state_batch_indices``
    and ``out``.

    PyTorch version signature:
        selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None,
                                   dt_bias=None, dt_softplus=False)
        Returns: output tensor

    """
    # Imported lazily, inside the function, as in the original code.
    from .pytorch_implementation import selective_state_update_ref

    # Adapt the reference function once; the closure below reuses it.
    adapted_ref = _wrap_selective_state_update_ref(selective_state_update_ref)

    def dispatcher(state,
                   x,
                   dt,
                   A,
                   B,
                   C,
                   D=None,
                   z=None,
                   dt_bias=None,
                   dt_softplus=False,
                   state_batch_indices=None,
                   dst_state_batch_indices=None,
                   out=None):
        # Currently a plain pass-through: every call goes to the adapted
        # PyTorch reference implementation.
        return adapted_ref(state,
                           x,
                           dt,
                           A,
                           B,
                           C,
                           D=D,
                           z=z,
                           dt_bias=dt_bias,
                           dt_softplus=dt_softplus,
                           state_batch_indices=state_batch_indices,
                           dst_state_batch_indices=dst_state_batch_indices,
                           out=out)

    return dispatcher

use_pytorch_ops

use_pytorch_ops() -> bool

Returns True if PyTorch implementations should be used.

Source code in vllm_gaudi/ops/ops_selector.py
def use_pytorch_ops() -> bool:
    """Returns True if PyTorch implementations should be used.

    This returns the module-level ``_USE_PYTORCH`` flag as-is; unlike
    ``_use_pytorch_runtime``, it does not read the environment on each call.
    """
    return _USE_PYTORCH

use_pytorch_selective_state_update_ref

use_pytorch_selective_state_update_ref() -> bool
Source code in vllm_gaudi/ops/ops_selector.py
def use_pytorch_selective_state_update_ref() -> bool:
    """Returns the module-level ``_USE_SELECTIVE_STATE_UPDATE_REF`` flag.

    The flag is a module attribute derived from the
    ``VLLM_MAMBA_USE_SELECTIVE_STATE_UPDATE_REF_PT`` setting (default "1",
    i.e. enabled); this function does not re-read it on each call.
    """
    return _USE_SELECTIVE_STATE_UPDATE_REF