Warm up Qwen Triton kernels from the loaded model's compile keys.
Functions:
qwen_triton_warmup(runner, model_config)
Warm Qwen Triton kernels reported by the JIT monitor.
Source code in vllm/model_executor/warmup/qwen_triton_warmup.py
| @torch.inference_mode()
def qwen_triton_warmup(
    runner: "GPUModelRunner",
    model_config: object,
) -> None:
    """Pre-run the Qwen Triton kernels so they are warmed before use.

    The function is a no-op for pooling models and for models whose
    ``model_type`` is not listed in ``_QWEN_MODEL_TYPES``.

    Args:
        runner: Model runner. ``is_pooling_model`` is read directly;
            ``device`` and ``compilation_config`` are read with fallbacks
            (``torch.device("cuda")`` and ``None`` respectively).
        model_config: Object optionally exposing ``hf_text_config`` and/or
            ``hf_config``; the first non-``None`` ``model_type`` found on
            them (in that order) selects whether warmup runs.
    """
    if runner.is_pooling_model:
        return

    # Both configs are fetched up front; ``hf_text_config`` takes priority.
    candidate_configs = (
        getattr(model_config, "hf_text_config", None),
        getattr(model_config, "hf_config", None),
    )
    # Lazily probe each config and stop at the first one that declares a
    # model type; falls back to None when neither does.
    declared_types = (
        getattr(candidate, "model_type", None) for candidate in candidate_configs
    )
    model_type = next(
        (str(declared) for declared in declared_types if declared is not None),
        None,
    )
    if model_type not in _QWEN_MODEL_TYPES:
        return

    fallback_device = torch.device("cuda")
    device = getattr(runner, "device", fallback_device)
    logger.info("Warming up Qwen Triton kernels for model_type=%s.", model_type)

    # The fallback config is resolved before trying the runner's own zeroer,
    # and is only consumed when that attempt reports failure.
    fallback_zero_config = _zero_kv_warmup_config(runner)
    if not _warm_zero_kv_blocks_with_runner_zeroer(runner):
        if fallback_zero_config is None:
            logger.info("Skipping Qwen zero-kv warmup: no KVBlockZeroer metadata.")
        else:
            _warm_zero_kv_blocks_kernel(device, fallback_zero_config)

    _warm_compute_slot_mapping_kernel(device)
    _synchronize_device(device)

    # GDN kernels are warmed only when a warmup config can be derived from
    # the runner's static forward context (either attribute may be absent).
    static_forward_context = getattr(
        getattr(runner, "compilation_config", None),
        "static_forward_context",
        None,
    )
    gdn_config = _qwen_gdn_warmup_config(static_forward_context)
    if gdn_config is not None:
        for warm_gdn_kernel in (
            _warm_causal_conv1d_fwd_kernel,
            _warm_fused_post_conv_kernel,
            _warm_fused_sigmoid_gating_delta_rule_update_kernel,
        ):
            warm_gdn_kernel(device, gdn_config)
        _synchronize_device(device)
|