Bases: PluggableLayer, MambaBase
Base class for linear attention layers.
Source code in vllm/model_executor/layers/mamba/linear/base.py
class LinearAttention(PluggableLayer, MambaBase):
    """Base class for linear attention layers.

    Reads the head geometry and tensor-parallel layout from the supplied
    configs at construction time, and reports the state dtype and shape by
    delegating to ``MambaStateDtypeCalculator`` and
    ``MambaStateShapeCalculator``.
    """

    def __init__(
        self, config: PretrainedConfig, vllm_config: VllmConfig, prefix: str = ""
    ):
        """Store the layer dimensions and parallel layout.

        Args:
            config: Model config providing ``hidden_size``,
                ``num_attention_heads`` and ``num_hidden_layers``, and
                optionally ``head_dim`` and ``block``.
            vllm_config: Provides the model, cache and quantization configs.
            prefix: Layer name; the layer index is extracted from it.
        """
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.prefix = prefix
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config

        # Falls back to 256 when the config does not define ``block``.
        self.BLOCK = getattr(config, "block", 256)
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_hidden_layers = config.num_hidden_layers

        # A config may declare ``head_dim`` yet leave it as None. Treat that
        # like a missing attribute and derive the value from the hidden size,
        # rather than failing on the multiplication below.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // self.num_heads
        self.head_dim = head_dim
        self.hidden_inner_size = self.head_dim * self.num_heads

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        assert self.num_heads % self.tp_size == 0, (
            f"num_heads ({self.num_heads}) must be divisible by "
            f"tp_size ({self.tp_size})"
        )

    @property
    def mamba_type(self) -> MambaAttentionBackendEnum:
        """Return the backend enum member for this layer (``LINEAR``)."""
        return MambaAttentionBackendEnum.LINEAR

    def get_state_dtype(self) -> tuple[torch.dtype]:
        """Return the state dtype(s) for this layer.

        Delegates to ``MambaStateDtypeCalculator.linear_attention_state_dtype``
        with the model dtype and the configured mamba cache dtype.
        """
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
        """Return the state shape(s) for this layer.

        Delegates to ``MambaStateShapeCalculator.linear_attention_state_shape``
        with this layer's head count, tensor-parallel size and head dimension.
        """
        return MambaStateShapeCalculator.linear_attention_state_shape(
            num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim
        )
|