Skip to content

vllm.model_executor.warmup.qwen_triton_warmup

Warm up Qwen Triton kernels from the loaded model's compile keys.

Functions:

qwen_triton_warmup(runner, model_config)

Warm Qwen Triton kernels reported by the JIT monitor.

Source code in `vllm/model_executor/warmup/qwen_triton_warmup.py`:
@torch.inference_mode()
def qwen_triton_warmup(
    runner: "GPUModelRunner",
    model_config: object,
) -> None:
    """Warm Qwen Triton kernels reported by the JIT monitor.

    Does nothing for pooling runners, or when the model type is not in
    ``_QWEN_MODEL_TYPES``. The model type is read from ``hf_text_config``
    first, then from ``hf_config``. Otherwise it warms, in order:

    1. The zero-KV-blocks kernel. It uses the runner's own zeroer when
       ``_warm_zero_kv_blocks_with_runner_zeroer`` returns a truthy value.
       Failing that, it uses the config from ``_zero_kv_warmup_config``.
       If neither is available, the step is skipped with an info log.
    2. The compute-slot-mapping kernel, followed by a device sync.
    3. The causal-conv1d, fused post-conv and fused sigmoid-gating
       delta-rule kernels, followed by a device sync. This step runs only
       when ``_qwen_gdn_warmup_config`` yields a config from the
       compilation config's ``static_forward_context``.

    Args:
        runner: Model runner. Reads ``is_pooling_model`` and, when present,
            ``device`` (defaults to ``torch.device("cuda")``) and
            ``compilation_config``.
        model_config: Object whose optional ``hf_text_config`` /
            ``hf_config`` attributes may carry a ``model_type``.
    """
    if runner.is_pooling_model:
        return

    # Both candidate configs are fetched up front. The first one exposing
    # a non-None model_type wins, so the text config takes precedence.
    candidate_configs = (
        getattr(model_config, "hf_text_config", None),
        getattr(model_config, "hf_config", None),
    )
    model_type = next(
        (
            str(found)
            for found in (
                getattr(cfg, "model_type", None) for cfg in candidate_configs
            )
            if found is not None
        ),
        None,
    )
    if model_type not in _QWEN_MODEL_TYPES:
        return

    device = getattr(runner, "device", torch.device("cuda"))
    logger.info("Warming up Qwen Triton kernels for model_type=%s.", model_type)

    # Resolved before the zeroer attempt. It is only consumed when that
    # attempt returns a falsy value.
    zero_config = _zero_kv_warmup_config(runner)
    if not _warm_zero_kv_blocks_with_runner_zeroer(runner):
        if zero_config is None:
            logger.info("Skipping Qwen zero-kv warmup: no KVBlockZeroer metadata.")
        else:
            _warm_zero_kv_blocks_kernel(device, zero_config)

    _warm_compute_slot_mapping_kernel(device)
    _synchronize_device(device)

    compilation_config = getattr(runner, "compilation_config", None)
    gdn_config = _qwen_gdn_warmup_config(
        getattr(compilation_config, "static_forward_context", None)
    )
    if gdn_config is None:
        return

    # GDN-related kernels share the same device and derived config.
    _warm_causal_conv1d_fwd_kernel(device, gdn_config)
    _warm_fused_post_conv_kernel(device, gdn_config)
    _warm_fused_sigmoid_gating_delta_rule_update_kernel(device, gdn_config)
    _synchronize_device(device)