Wraps TorchInductor / AOTAutograd compile entry points so that the synchronizing ops performed by those passes don't trip the sync-check mode set around `execute_model` / `sample_tokens`.
Warmup-time compiles run before `enable_gpu_sync_check` is called, so they are unaffected. Post-warmup recompiles, however, fire inside `execute_model`, where they would otherwise trip the sync check.
Source code in vllm/utils/gpu_sync_debug.py
def _install_compile_time_sync_suppressors() -> None:
    """Disable the CUDA sync-debug check while Inductor joint-graph passes run.

    The sync-check mode set around `execute_model` / `sample_tokens`
    (via `torch.cuda.set_sync_debug_mode`) flags synchronizing ops.
    Warmup-time compiles run before `enable_gpu_sync_check` is called, so
    they are unaffected. Post-warmup recompiles, however, fire inside
    `execute_model`, and the synchronizing ops those compile passes perform
    would trip the check.

    Currently only `torch._inductor.fx_passes.joint_graph.joint_graph_passes`
    is wrapped. While it runs with a non-zero sync debug mode, the mode is
    set to 0 and the previous mode is restored afterwards, even if the pass
    raises.

    Runs at most once per process (guarded by the module-level
    `_compile_time_suppressors_installed` flag). Installation is
    best-effort: any failure is logged at debug level and otherwise
    ignored, so an incompatible torch version never breaks startup.
    """
    import logging
    import sys

    global _compile_time_suppressors_installed
    if _compile_time_suppressors_installed:
        return
    # Mark as installed up front: a failed install would fail the same way
    # again, so there is no point retrying on every call.
    _compile_time_suppressors_installed = True

    log = logging.getLogger(__name__)

    try:
        from torch._inductor.fx_passes import joint_graph as _jg
    except ImportError:
        log.debug(
            "torch._inductor.fx_passes.joint_graph is unavailable; "
            "skipping compile-time sync-check suppression."
        )
        return

    try:
        _orig_joint = _jg.joint_graph_passes

        @functools.wraps(_orig_joint)
        def _wrapped_joint(*args, **kwargs):
            prev_mode = torch.cuda.get_sync_debug_mode()
            if not prev_mode:
                # Sync check is off: nothing to suppress, call straight through.
                return _orig_joint(*args, **kwargs)
            torch.cuda.set_sync_debug_mode(0)
            try:
                return _orig_joint(*args, **kwargs)
            finally:
                # Always restore the caller's mode, even if the pass raises.
                torch.cuda.set_sync_debug_mode(prev_mode)

        # Patch the defining module first so that any module imported from
        # now on picks up the wrapper.
        _jg.joint_graph_passes = _wrapped_joint

        # `compile_fx` does `from .fx_passes.joint_graph import
        # joint_graph_passes`, which binds the *function object* at import
        # time. Patching only the defining module won't update that binding,
        # so also patch every already-imported reference to the original.
        # The scan is restricted to torch's compile-time packages.
        compile_prefixes = ("torch._inductor", "torch._functorch", "torch._dynamo")
        for mod_name, mod in list(sys.modules.items()):
            if mod is None or not mod_name.startswith(compile_prefixes):
                continue
            if getattr(mod, "joint_graph_passes", None) is _orig_joint:
                mod.joint_graph_passes = _wrapped_joint
    except Exception:  # noqa: BLE001  # pragma: no cover
        # Best-effort: never let a torch-internal layout change break
        # startup, but leave a trace for debugging missed suppressions.
        log.debug(
            "Failed to install compile-time sync-check suppressors.",
            exc_info=True,
        )