
vllm.model_executor.layers.fused_qk_norm_rope

Fused QK-RMSNorm + (partial) RoPE + gate copy Triton kernel.

Currently used by the Qwen3.5 attention path (attn_output_gate with NeoX-style partial RoPE). The unfused reference sequence is split -> GemmaRMSNorm -> RoPE -> gate chunk; this module collapses it into a single Triton kernel launch. See fused_qk_rmsnorm_rope_gate().
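For small inputs, the unfused sequence can be written out step by step in plain Python. This is a minimal sketch for illustration, not vLLM code. The helper names (gemma_rms_norm, neox_partial_rope, unfused_reference) are hypothetical, and tensors are modeled as nested lists with one row per token. The sketch also assumes the NeoX convention, in which element i of the rotary slice is paired with element i + rotary_dim // 2.

```python
import math


def gemma_rms_norm(x, w_eff, eps):
    # Normalize by the root-mean-square over the whole head, then scale by the
    # effective weight w_eff (= 1 + w; the "+1" is assumed already folded in).
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, w_eff)]


def neox_partial_rope(x, cos, sin, rotary_dim):
    # NeoX-style RoPE (assumed): element i pairs with element i + rotary_dim // 2.
    # Elements at index >= rotary_dim are passed through unchanged (partial RoPE).
    half = rotary_dim // 2
    out = list(x)
    for i in range(half):
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * cos[i] - x2 * sin[i]
        out[i + half] = x2 * cos[i] + x1 * sin[i]
    return out


def unfused_reference(q_gate, k, q_weight, k_weight, cos_sin_cache, positions,
                      eps, num_q_heads, num_kv_heads, head_dim, rotary_dim):
    # Hypothetical list-based reference: split -> GemmaRMSNorm -> RoPE -> gate chunk.
    half = rotary_dim // 2
    q_out, k_out, gate_out = [], [], []
    for t, pos in enumerate(positions):
        row = cos_sin_cache[pos]                  # packed [cos | sin]
        cos, sin = row[:half], row[half:rotary_dim]
        q_row, g_row, k_row = [], [], []
        for h in range(num_q_heads):
            # split: each Q head occupies 2 * head_dim elements laid out as [q | gate]
            base = h * 2 * head_dim
            q = q_gate[t][base:base + head_dim]
            q_row += neox_partial_rope(gemma_rms_norm(q, q_weight, eps),
                                       cos, sin, rotary_dim)
            g_row += q_gate[t][base + head_dim:base + 2 * head_dim]  # raw gate copy
        for h in range(num_kv_heads):
            kh = k[t][h * head_dim:(h + 1) * head_dim]
            k_row += neox_partial_rope(gemma_rms_norm(kh, k_weight, eps),
                                       cos, sin, rotary_dim)
        q_out.append(q_row)
        k_out.append(k_row)
        gate_out.append(g_row)
    return q_out, k_out, gate_out
```

Because the gate half is copied out untouched, gate_out from a reference like this should match the gate slices of q_gate exactly. That makes it a convenient first check when comparing against the fused kernel.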

Functions:

fused_qk_rmsnorm_rope_gate(q_gate, k, q_weight, k_weight, cos_sin_cache, positions, eps, num_q_heads, num_kv_heads, head_dim, rotary_dim)

Fused split + QK-RMSNorm + (partial) RoPE + gate copy for Qwen3.5 attn.

Parameters:

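  • q_gate (Tensor) – shape (n_tokens, num_q_heads * 2 * head_dim). Each head occupies 2 * head_dim consecutive elements laid out as [q | gate].

  • k (Tensor) – shape (n_tokens, num_kv_heads * head_dim).

  • q_weight (Tensor) – shape (head_dim,). Effective GemmaRMSNorm weight for Q, with the +1 offset already applied (i.e. 1 + weight).

  • k_weight (Tensor) – shape (head_dim,). Effective GemmaRMSNorm weight for K, with the +1 offset already applied (i.e. 1 + weight).

  • cos_sin_cache (Tensor) – shape (max_pos, rotary_dim). Each row packs rotary_dim // 2 cosines followed by rotary_dim // 2 sines ([cos | sin]).

  • positions (Tensor) – shape (n_tokens,), dtype int32 or int64. Position of each token, used as the row index into cos_sin_cache.

  • eps (float) – RMSNorm epsilon.

  • num_q_heads (int) – number of Q heads after the tensor-parallel (TP) split.

  • num_kv_heads (int) – number of KV heads after the tensor-parallel (TP) split.

  • head_dim (int) – per-head dimension.

  • rotary_dim (int) – rotary dimension; must be a positive even integer <= head_dim, otherwise ValueError is raised. When rotary_dim < head_dim, the trailing head_dim - rotary_dim elements of each head are normalized but not rotated.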
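The packed inputs can be prepared as follows. This minimal sketch assumes the standard RoPE frequency schedule base ** (-2i / rotary_dim) with a hypothetical default base of 10000.0. Real models may use a different rope_theta or a rope-scaling variant. Both helper names are hypothetical, and lists stand in for tensors.

```python
import math


def build_cos_sin_cache(max_pos, rotary_dim, base=10000.0):
    # Assumed standard RoPE inverse frequencies: base ** (-2i / rotary_dim),
    # for i in range(rotary_dim // 2). `base` is an illustrative default.
    half = rotary_dim // 2
    inv_freq = [base ** (-2.0 * i / rotary_dim) for i in range(half)]
    cache = []
    for pos in range(max_pos):
        angles = [pos * f for f in inv_freq]
        # One row per position, packed as [cos_0 .. cos_{half-1} | sin_0 .. sin_{half-1}].
        cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache


def gemma_effective_weight(w):
    # GemmaRMSNorm scales by (1 + w); q_weight / k_weight expect this sum precomputed.
    return [1.0 + v for v in w]
```

Every positions entry must be a valid row index, i.e. in [0, max_pos).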

Returns:

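  • tuple[Tensor, Tensor, Tensor]

    (q_out, k_out, gate_out), all contiguous. q_out and gate_out have shape (n_tokens, num_q_heads * head_dim) and the dtype of q_gate; k_out has shape (n_tokens, num_kv_heads * head_dim) and the dtype of k. gate_out is the raw (pre-sigmoid) gate.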
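Because gate_out is returned before the sigmoid, the consumer has to apply it. The sketch below assumes the usual attn_output_gate formulation, in which the attention output (same shape as q_out) is multiplied elementwise by sigmoid(gate). The vLLM call site is not shown here, and the helpers output_shapes and apply_output_gate are hypothetical.

```python
import math


def output_shapes(n_tokens, num_q_heads, num_kv_heads, head_dim):
    # Shapes of (q_out, k_out, gate_out) as allocated by the wrapper.
    q_shape = (n_tokens, num_q_heads * head_dim)
    return q_shape, (n_tokens, num_kv_heads * head_dim), q_shape


def sigmoid(x):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def apply_output_gate(attn_out, gate_out):
    # Assumed gating: out = attn_out * sigmoid(gate), elementwise per row.
    return [[a * sigmoid(g) for a, g in zip(a_row, g_row)]
            for a_row, g_row in zip(attn_out, gate_out)]
```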

Source code in vllm/model_executor/layers/fused_qk_norm_rope.py
def fused_qk_rmsnorm_rope_gate(
    q_gate: torch.Tensor,
    k: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    positions: torch.Tensor,
    eps: float,
    num_q_heads: int,
    num_kv_heads: int,
    head_dim: int,
    rotary_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused split + QK-RMSNorm + (partial) RoPE + gate copy for Qwen3.5 attn.

    Args:
        q_gate: (n_tokens, num_q_heads * 2 * head_dim) -- per head: [q|gate]
        k: (n_tokens, num_kv_heads * head_dim)
        q_weight: (head_dim,) GemmaRMSNorm effective weight (already +1)
        k_weight: (head_dim,) GemmaRMSNorm effective weight (already +1)
        cos_sin_cache: (max_pos, rotary_dim) packed [cos|sin]
        positions: (n_tokens,) int32 or int64
        eps: RMSNorm epsilon
        num_q_heads: number of Q heads (after TP split)
        num_kv_heads: number of KV heads (after TP split)
        head_dim: per-head dimension
        rotary_dim: rotary dimension; must be even and <= head_dim

    Returns:
        (q_out, k_out, gate_out) -- all contiguous (n_tokens, heads * head_dim).
        ``gate_out`` is the raw (pre-sigmoid) gate.
    """
    if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0:
        raise ValueError(
            f"rotary_dim must be a positive even integer <= head_dim, "
            f"got rotary_dim={rotary_dim}, head_dim={head_dim}"
        )

    n_tokens = q_gate.shape[0]
    q_out = torch.empty(
        (n_tokens, num_q_heads * head_dim), dtype=q_gate.dtype, device=q_gate.device
    )
    k_out = torch.empty(
        (n_tokens, num_kv_heads * head_dim), dtype=k.dtype, device=k.device
    )
    gate_out = torch.empty_like(q_out)
    if n_tokens == 0:
        return q_out, k_out, gate_out

    half_rotary = rotary_dim // 2
    head_block = triton.next_power_of_2(head_dim)
    rot_half_block = triton.next_power_of_2(half_rotary)
    num_warps = max(1, head_block // 64)

    grid = (n_tokens, num_q_heads + num_kv_heads)
    _fused_qk_rmsnorm_rope_gate_kernel[grid](
        q_gate,
        k,
        q_out,
        k_out,
        gate_out,
        q_weight,
        k_weight,
        cos_sin_cache,
        positions,
        q_gate.stride(0),
        k.stride(0),
        q_out.stride(0),
        k_out.stride(0),
        gate_out.stride(0),
        cos_sin_cache.stride(0),
        num_q_heads,
        num_kv_heads,
        head_dim,
        rotary_dim,
        half_rotary,
        eps,
        INPUT_DTYPE=tl.bfloat16 if q_gate.dtype == torch.bfloat16 else tl.float16,
        HEAD_BLOCK=head_block,
        ROT_HALF_BLOCK=rot_half_block,
        HAS_PASS=rotary_dim < head_dim,
        num_warps=num_warps,
        num_stages=2,
    )
    return q_out, k_out, gate_out
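The host-side launch arithmetic in the listing above can be reproduced without a GPU. This is a sketch: next_power_of_2 is a local stand-in that agrees with triton.next_power_of_2 for n >= 1, and launch_config is a hypothetical helper. The grid launches one program per (token, head) pair across all Q and KV heads. num_warps allots one warp per 64 lanes of HEAD_BLOCK, with a minimum of one.

```python
def next_power_of_2(n):
    # Smallest power of two >= n, for n >= 1 (stand-in for triton.next_power_of_2).
    return 1 << (n - 1).bit_length()


def launch_config(n_tokens, num_q_heads, num_kv_heads, head_dim, rotary_dim):
    # Same validation as the wrapper.
    if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0:
        raise ValueError(
            f"rotary_dim must be a positive even integer <= head_dim, "
            f"got rotary_dim={rotary_dim}, head_dim={head_dim}"
        )
    half_rotary = rotary_dim // 2
    head_block = next_power_of_2(head_dim)
    # Note: the real wrapper returns early, without launching, when n_tokens == 0.
    return {
        "grid": (n_tokens, num_q_heads + num_kv_heads),
        "HEAD_BLOCK": head_block,
        "ROT_HALF_BLOCK": next_power_of_2(half_rotary),
        "HAS_PASS": rotary_dim < head_dim,   # trailing non-rotated elements exist
        "num_warps": max(1, head_block // 64),
    }


# Illustrative values only: 8 tokens, 16 Q heads, 2 KV heads, head_dim 256, rotary_dim 64.
print(launch_config(8, 16, 2, 256, 64))
```

For a head_dim that is not a power of two (e.g. 96), HEAD_BLOCK rounds up to 128. The kernel, which is not shown here, presumably masks the extra lanes.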