Bases: BaseModel
SpinQuant needs to know the entire architecture of the model, as R1, R2, R3, and R4 rotations need to be applied to specific layers (https://arxiv.org/pdf/2405.16406 Fig. 1).
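Why the mapping has to cover the whole model: a rotation R applied to the output of one layer must be undone by R^T on the input of every layer that reads that output. Otherwise the network computes a different function. The plain-Python sketch below shows this pairing. It uses a normalized Hadamard matrix as the orthogonal rotation; the SpinQuant paper learns R1 and R2, so the Hadamard choice and all weight values here are illustrative assumptions.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def hadamard(n: int) -> Matrix:
    """Normalized Sylvester Hadamard matrix (n a power of two); orthogonal, so H @ H.T = I."""
    h = [[1.0]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-x for x in row] for row in h]
    scale = 1.0 / math.sqrt(n)
    return [[x * scale for x in row] for row in h]


def allclose(a: Matrix, b: Matrix, tol: float = 1e-9) -> bool:
    return all(abs(x - y) <= tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Toy weights (hypothetical): the producer writes a 4-dim hidden state, the consumer reads it.
w_producer = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0], [-2.0, 1.5]]  # 4 x 2
w_consumer = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.0, 0.0, 3.0], [-1.0, 2.0, 1.0, 0.0]]  # 3 x 4
r = hadamard(4)

original = matmul(w_consumer, w_producer)
# Rotate the producer's output by R and absorb R^T into the consumer's input.
paired = matmul(matmul(w_consumer, transpose(r)), matmul(r, w_producer))
# Rotate only the producer: nothing undoes R, so the composed function changes.
one_sided = matmul(w_consumer, matmul(r, w_producer))

print(allclose(original, paired))     # True
print(allclose(original, one_sided))  # False
```

This is why a layer missing from the mapping is not harmless: it would see rotated activations without having its weights counter-rotated.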
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embedding | str | name or regex of the embedding layer | required |
| attn_q | str | name or regex of the q_proj layer in the attention block | required |
| attn_k | str | name or regex of the k_proj layer in the attention block | required |
| attn_v | str | name or regex of the v_proj layer in the attention block | required |
| attn_o | str | name or regex of the o_proj layer in the attention block | required |
| attn_head_dim | Optional[int] | head_dim of the attention module; needed because R2 is applied head-wise to v_proj and o_proj | None |
| mlp_in | List[str] | list of names or regexes for the MLP layers that receive the input to the MLP block, usually up_proj and gate_proj; a single string is wrapped in a list | required |
| mlp_out | List[str] | list of names or regexes for the MLP layers that constitute the output of the MLP block, usually down_proj; a single string is wrapped in a list | required |
| lm_head | str | name or regex of the lm_head layer | required |
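To make the "name or regex" fields concrete, here is a hypothetical set of values for a Llama-style decoder, together with a small resolver that shows which modules each field would select. Several parts of this example are assumptions: the "re:" prefix for regexes is borrowed from llm-compressor's targeting convention, and the module names follow the Hugging Face Llama layout. The `resolve` helper is illustrative only and is not llm-compressor's matching code.

```python
import re
from typing import Dict, List, Union

# Hypothetical values for a Llama-style decoder (attn_head_dim is optional and omitted).
llama_style_mapping: Dict[str, Union[str, List[str]]] = {
    "embedding": "re:.*embed_tokens$",
    "attn_q": "re:.*q_proj$",
    "attn_k": "re:.*k_proj$",
    "attn_v": "re:.*v_proj$",
    "attn_o": "re:.*o_proj$",
    "mlp_in": ["re:.*up_proj$", "re:.*gate_proj$"],
    "mlp_out": "re:.*down_proj$",  # a bare string; SpinQuantMapping wraps it in a list
    "lm_head": "lm_head",
}


def resolve(target: str, module_names: List[str]) -> List[str]:
    """Hypothetical helper: names matched by an exact name or a 're:'-prefixed regex."""
    if target.startswith("re:"):
        pattern = re.compile(target[len("re:"):])
        return [name for name in module_names if pattern.fullmatch(name)]
    return [name for name in module_names if name == target]


# Module names of a toy two-layer Llama-style model (assumed naming).
attn = ("q_proj", "k_proj", "v_proj", "o_proj")
mlp = ("gate_proj", "up_proj", "down_proj")
module_names = (
    ["model.embed_tokens"]
    + [f"model.layers.{i}.self_attn.{p}" for i in range(2) for p in attn]
    + [f"model.layers.{i}.mlp.{p}" for i in range(2) for p in mlp]
    + ["lm_head"]
)

print(resolve(llama_style_mapping["attn_v"], module_names))
# ['model.layers.0.self_attn.v_proj', 'model.layers.1.self_attn.v_proj']

mlp_in_matches = [n for t in llama_style_mapping["mlp_in"] for n in resolve(t, module_names)]
print(len(mlp_in_matches))  # 4
```

Anchoring each regex with `$` keeps, for example, `o_proj` from accidentally matching other projections whose names merely contain it.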
Source code in llmcompressor/modifiers/transform/spinquant/mappings.py
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator


class SpinQuantMapping(BaseModel):
    """
    SpinQuant needs to know the entire architecture of the model,
    as R1, R2, R3, and R4 rotations need to be applied to specific
    layers (https://arxiv.org/pdf/2405.16406 Fig. 1).

    :param embedding: name or regex of embedding layer
    :param attn_q: name or regex of q_proj layer in attention block
    :param attn_k: name or regex of k_proj layer in attention block
    :param attn_v: name or regex of v_proj layer in attention block
    :param attn_o: name or regex of o_proj layer in attention block
    :param attn_head_dim: head_dim of the attention module, needed
        because R2 needs to be applied "head-wisely" to v_proj and
        o_proj
    :param mlp_in: list of names or regexes for the mlp blocks that
        receive the input to the MLP block, usually up_proj and gate_proj
    :param mlp_out: list of names or regexes for the mlp blocks that
        constitute the output of the MLP block, usually down_proj
    :param lm_head: name or regex of the lm_head layer
    """

    embedding: str

    attn_q: str
    attn_k: str
    attn_v: str
    attn_o: str
    attn_head_dim: Optional[int] = Field(default=None)

    mlp_in: List[str]  # up_proj, gate_proj
    mlp_out: List[str]  # down_proj

    lm_head: str

    @field_validator("mlp_in", "mlp_out", mode="before")
    @classmethod
    def cast_to_list(cls, value):
        # Accept a single name/regex string and wrap it in a one-element list
        if isinstance(value, str):
            return [value]
        return value
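The `attn_head_dim` field exists because R2 is applied head-wise to v_proj and o_proj. In the SpinQuant paper, R2 is a head_dim x head_dim rotation applied separately to each head's slice of the value output, and its inverse is folded into o_proj. Without head_dim, the hidden dimension cannot be split into heads. The plain-Python sketch below shows this block-diagonal application. The Hadamard choice for R2, the toy sizes, and the helper names are assumptions made for illustration.

```python
import math
from typing import List


def hadamard(n: int) -> List[List[float]]:
    """Normalized Sylvester Hadamard matrix (n a power of two); it is orthogonal."""
    h = [[1.0]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-x for x in row] for row in h]
    scale = 1.0 / math.sqrt(n)
    return [[x * scale for x in row] for row in h]


def transpose(m: List[List[float]]) -> List[List[float]]:
    return [list(row) for row in zip(*m)]


def rotate_head_wise(x: List[float], rotation: List[List[float]]) -> List[float]:
    """Apply one head_dim x head_dim rotation to every head-sized slice of x.

    This equals multiplying x by a block-diagonal matrix with `rotation` repeated
    once per head, so values are never mixed across heads.
    """
    head_dim = len(rotation)
    if len(x) % head_dim != 0:
        raise ValueError("len(x) must be a multiple of head_dim")
    out: List[float] = []
    for start in range(0, len(x), head_dim):
        head = x[start:start + head_dim]
        out.extend(
            sum(rotation[i][j] * head[j] for j in range(head_dim)) for i in range(head_dim)
        )
    return out


head_dim, num_heads = 4, 2  # toy sizes, not taken from any real model
r2 = hadamard(head_dim)
v_out = [float(i) for i in range(1, head_dim * num_heads + 1)]  # stand-in v_proj output

rotated = rotate_head_wise(v_out, r2)
print(rotated)  # [5.0, -1.0, -2.0, 0.0, 13.0, -1.0, -2.0, 0.0]

# Folding R2^T into o_proj's input undoes the rotation head by head.
restored = rotate_head_wise(rotated, transpose(r2))
print(restored == v_out)  # True
```

Because each head is rotated independently, changing the values of one head leaves the rotated output of every other head unchanged. A single hidden-size rotation would not have this property.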