
vllm_omni.diffusion.hooks.base

Base hook classes for model forward interception.

This module provides the foundational hook mechanism that allows intercepting and modifying model forward passes without invasive changes to model code.

logger module-attribute

logger = init_logger(__name__)

BaseState

Base class for hook state containers.

reset

reset() -> None

HookRegistry

Registry of hooks attached to a module.

Manages multiple hooks that can intercept a module's forward pass. Hooks are called in sorted order by name for determinism.

module instance-attribute

module = module

check_if_exists_or_initialize classmethod

check_if_exists_or_initialize(
    module: Module,
) -> HookRegistry

Get existing registry or create a new one for the module.

This method ensures a HookRegistry exists on the module and returns it. If no registry exists yet, it creates one and attaches it to the module. It is equivalent to get_or_create() and is provided for compatibility with the diffusers API.

Parameters:

  • module (Module, required): The module to get/create a registry for.

Returns:

  • HookRegistry: The HookRegistry for this module.

dispatch

dispatch(*args: Any, **kwargs: Any) -> Any

Dispatch a forward call through registered hooks.

Multiple hooks may be registered, with the caveat that only one hook may override new_forward. Pre- and post-processing are assumed to be composable across hooks; for determinism, execution proceeds as follows:

  • Run pre-processing (pre_forward) on all hooks in sorted order. Hooks are sorted alphabetically by name, except that the hook overriding forward (self._new_fwd_impl_hook), if one exists, is placed last.

  • If self._new_fwd_impl_hook is not None, call its new_forward; otherwise, call the module's original forward.

  • Run post process on all hooks in the reverse sorted order.
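The ordering rules above can be modeled in plain Python. This is a minimal, hypothetical sketch of the dispatch flow, not the vllm_omni implementation: `RecordingHook`, its `overrides_forward` flag, and the standalone `dispatch` function are illustrative names, and a plain callable stands in for the module's original forward.

```python
from typing import Any, Callable


class RecordingHook:
    """Hypothetical hook that records when each entry point runs."""

    def __init__(self, name: str, log: list, overrides_forward: bool = False):
        self.name = name
        self.log = log
        # Stand-in for "this hook overrides new_forward".
        self.overrides_forward = overrides_forward

    def pre_forward(self, module: Any, *args: Any, **kwargs: Any):
        self.log.append(f"pre:{self.name}")
        return args, kwargs

    def new_forward(self, module: Any, *args: Any, **kwargs: Any) -> Any:
        self.log.append(f"new_forward:{self.name}")
        return 2 * args[0]  # replacement computation for the demo

    def post_forward(self, module: Any, output: Any) -> Any:
        self.log.append(f"post:{self.name}")
        return output


def dispatch(module: Any, original_forward: Callable[..., Any],
             hooks: dict, *args: Any, **kwargs: Any) -> Any:
    overriding = [h for h in hooks.values() if h.overrides_forward]
    if len(overriding) > 1:
        raise ValueError("only one hook may override new_forward")
    fwd_hook = overriding[0] if overriding else None

    # Alphabetical by name, with the forward-overriding hook moved to the end.
    ordered = [hooks[n] for n in sorted(hooks) if hooks[n] is not fwd_hook]
    if fwd_hook is not None:
        ordered.append(fwd_hook)

    for hook in ordered:
        args, kwargs = hook.pre_forward(module, *args, **kwargs)

    if fwd_hook is not None:
        output = fwd_hook.new_forward(module, *args, **kwargs)
    else:
        output = original_forward(*args, **kwargs)

    # Post-processing unwinds in reverse order.
    for hook in reversed(ordered):
        output = hook.post_forward(module, output)
    return output


log: list = []
hooks = {
    "b": RecordingHook("b", log),
    "a": RecordingHook("a", log, overrides_forward=True),
    "c": RecordingHook("c", log),
}
print(dispatch(None, lambda x: x + 1, hooks, 5))  # 10
print(log)
```

Hook "a" sorts first alphabetically but runs last in pre-processing because it overrides the forward, so the log reads pre:b, pre:c, pre:a, new_forward:a, post:a, post:c, post:b, and the result is 10 rather than the original forward's 6.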

Parameters:

  • *args (Any, default ()): Positional arguments to forward.

  • **kwargs (Any, default {}): Keyword arguments to forward.

Returns:

  • Any: The output of the forward pass.

get_hook

get_hook(name: str) -> ModelHook | None

Get a hook by name.

Parameters:

  • name (str, required): The name of the hook.

Returns:

  • ModelHook | None: The hook if found, None otherwise.

get_or_create classmethod

get_or_create(module: Module) -> HookRegistry

Get existing registry or create a new one for the module.
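A minimal sketch of the get-or-create pattern, assuming the registry is cached as an attribute on the module. The attribute name `_hook_registry` and the `Module` stand-in class are hypothetical: the real attribute name is not documented here, and the real registry is attached to a torch.nn.Module.

```python
class Module:
    """Stand-in for torch.nn.Module so the sketch runs without PyTorch."""


class HookRegistry:
    # Hypothetical attribute name used to cache the registry on the module.
    _ATTR = "_hook_registry"

    def __init__(self, module: Module):
        self.module = module
        self.hooks: dict = {}

    @classmethod
    def get_or_create(cls, module: Module) -> "HookRegistry":
        registry = getattr(module, cls._ATTR, None)
        if registry is None:
            # First call for this module: create and attach a new registry.
            registry = cls(module)
            setattr(module, cls._ATTR, registry)
        return registry

    # Documented as equivalent to get_or_create(); modeled here as an alias.
    check_if_exists_or_initialize = get_or_create


first = Module()
second = Module()
print(HookRegistry.get_or_create(first) is HookRegistry.check_if_exists_or_initialize(first))  # True
print(HookRegistry.get_or_create(first) is HookRegistry.get_or_create(second))  # False
```

Repeated calls on the same module return the same registry object, so hooks registered through one call remain visible to later callers.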

Parameters:

  • module (Module, required): The module to get/create a registry for.

Returns:

  • HookRegistry: The HookRegistry for this module.

register_hook

register_hook(name: str, hook: ModelHook) -> None

Register a hook with the given name.

Parameters:

  • name (str, required): Unique name for this hook.

  • hook (ModelHook, required): The hook instance to register.

remove_hook

remove_hook(name: str) -> None

Remove a hook by name.

Parameters:

  • name (str, required): The name of the hook to remove.

reset

reset() -> None

Reset all hooks and clear the registry.

This removes all hooks from the registry and resets each hook's state. Also restores module.forward to its original implementation.
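The interception and its undo can be sketched by shadowing `forward` on the module instance. In this hypothetical sketch, `MiniRegistry` keeps the original bound method in `_omni_original_forward` (the attribute named in the ModelHook.new_forward description), and `reset()` resets each hook, clears the registry, and puts the original forward back. It omits the forward-override path and uses a plain class in place of torch.nn.Module; `DoubleOutput` is an illustrative hook.

```python
from typing import Any


class Module:
    """Stand-in for torch.nn.Module with a plain forward method."""

    def forward(self, x: int) -> int:
        return x + 1


class DoubleOutput:
    """Hypothetical hook: counts calls and doubles the output."""

    def __init__(self) -> None:
        self.calls = 0

    def pre_forward(self, module: Any, *args: Any, **kwargs: Any):
        self.calls += 1
        return args, kwargs

    def post_forward(self, module: Any, output: Any) -> Any:
        return output * 2

    def reset_state(self, module: Any) -> Any:
        self.calls = 0
        return module


class MiniRegistry:
    """Hypothetical registry that intercepts module.forward and can undo it."""

    def __init__(self, module: Module) -> None:
        self.module = module
        self.hooks: dict = {}
        # Save the original bound method, then shadow forward on the instance.
        module._omni_original_forward = module.forward
        module.forward = self.dispatch

    def dispatch(self, *args: Any, **kwargs: Any) -> Any:
        names = sorted(self.hooks)
        for name in names:
            args, kwargs = self.hooks[name].pre_forward(self.module, *args, **kwargs)
        output = self.module._omni_original_forward(*args, **kwargs)
        for name in reversed(names):
            output = self.hooks[name].post_forward(self.module, output)
        return output

    def reset(self) -> None:
        # Reset each hook's state, drop all hooks, and restore the original forward.
        for hook in self.hooks.values():
            hook.reset_state(self.module)
        self.hooks.clear()
        self.module.forward = self.module._omni_original_forward
        del self.module._omni_original_forward


m = Module()
registry = MiniRegistry(m)
registry.hooks["double"] = DoubleOutput()
print(m.forward(3))  # 8: original forward gives 4, post_forward doubles it
registry.reset()
print(m.forward(3))  # 4: original forward restored
```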

reset_hook

reset_hook(name: str) -> None

Reset a hook's state by name.

Parameters:

  • name (str, required): The name of the hook to reset.

update_sorted_hooks

update_sorted_hooks()

Sort hooks by name, which dictates pre/post process order.
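Because the sort order drives pre/post-processing, it must be refreshed whenever hooks are added or removed, which is the role of the sort_hooks_after_call decorator in this module. The sketch below uses a hypothetical, simplified `Registry` with a `sorted_names` list; raising ValueError on a duplicate name is an assumption based on register_hook requiring a unique name.

```python
import functools


def sort_hooks_after_call(func):
    """Run the wrapped registry method, then re-sort the registry's hooks."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)
        self.update_sorted_hooks()
        return result

    return wrapper


class Registry:
    """Hypothetical minimal registry that keeps hook names sorted."""

    def __init__(self):
        self.hooks = {}
        self.sorted_names = []

    def update_sorted_hooks(self):
        self.sorted_names = sorted(self.hooks)

    @sort_hooks_after_call
    def register_hook(self, name, hook):
        if name in self.hooks:
            raise ValueError(f"hook {name!r} is already registered")
        self.hooks[name] = hook

    @sort_hooks_after_call
    def remove_hook(self, name):
        self.hooks.pop(name, None)


registry = Registry()
registry.register_hook("b_hook", object())
registry.register_hook("a_hook", object())
print(registry.sorted_names)  # ['a_hook', 'b_hook']
```

Applying the re-sort in a decorator keeps each mutating method free of bookkeeping and makes it hard to forget the refresh when a new mutating method is added.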

ModelHook

Base class for model hooks that can override a module's forward.

Hooks can intercept the forward pass at two points:

  • pre_forward: called before the original forward; can modify args/kwargs.

  • post_forward: called after the original forward; can modify the output.

Subclasses can override either or both methods. The default implementations pass through args/kwargs/output unchanged.

For more complex behavior, override new_forward to completely replace the forward logic.
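A minimal sketch of the pre_forward/post_forward contract. The `ModelHook` class here is a stand-in that mirrors the documented pass-through defaults rather than the real vllm_omni class, and `ScaleInputHook`, `RoundOutputHook`, and `run_with_hook` are hypothetical names used for illustration.

```python
from __future__ import annotations

from typing import Any, Callable


class ModelHook:
    """Stand-in mirroring the documented pass-through defaults."""

    def initialize_hook(self, module: Any) -> Any:
        return module

    def pre_forward(self, module: Any, *args: Any, **kwargs: Any) -> tuple[tuple, dict]:
        return args, kwargs

    def post_forward(self, module: Any, output: Any) -> Any:
        return output

    def reset_state(self, module: Any) -> Any:
        return module


class ScaleInputHook(ModelHook):
    """Hypothetical hook: scales the first positional argument before forward."""

    def __init__(self, factor: float):
        self.factor = factor

    def pre_forward(self, module, *args, **kwargs):
        if args:
            args = (args[0] * self.factor, *args[1:])
        return args, kwargs


class RoundOutputHook(ModelHook):
    """Hypothetical hook: rounds a numeric output after forward."""

    def post_forward(self, module, output):
        return round(output, 2)


def run_with_hook(hook: ModelHook, forward: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Apply a single hook around a plain forward callable (module omitted)."""
    args, kwargs = hook.pre_forward(None, *args, **kwargs)
    return hook.post_forward(None, forward(*args, **kwargs))


print(run_with_hook(ScaleInputHook(3), lambda x, y=0: x + y, 2, y=1))  # 7
print(run_with_hook(RoundOutputHook(), lambda x: x / 3, 1))  # 0.33
```

Each subclass overrides only the entry point it needs; the other falls back to the pass-through default.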

initialize_hook

initialize_hook(module: Module) -> Module

Initialize the hook when it's registered to a module.

Parameters:

  • module (Module, required): The module this hook is being attached to.

Returns:

  • Module: The module (possibly modified).

new_forward

new_forward(
    module: Module, *args: Any, **kwargs: Any
) -> Any

Override the module's forward pass. This should be overridden for more complex cases, e.g., TeaCache. If this method is overridden in a subclass, it will be called instead of self.module._omni_original_forward when executing the hooks.
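As an illustration of overriding new_forward, the sketch below reuses the previous output when the input has barely changed. It is a toy criterion, not TeaCache's actual algorithm; `CachingForwardHook`, `threshold`, and `SquareModule` are hypothetical, and the stand-in module defines `_omni_original_forward` directly instead of having it installed by a registry.

```python
from typing import Any, Optional


class SquareModule:
    """Stand-in module whose saved original forward squares its input."""

    def _omni_original_forward(self, x: float) -> float:
        return x * x


class CachingForwardHook:
    """Hypothetical hook that replaces the forward via new_forward."""

    def __init__(self, threshold: float = 1e-3):
        self.threshold = threshold
        self.last_input: Optional[float] = None
        self.last_output: Any = None
        self.computed = 0  # number of real forward calls

    def new_forward(self, module: Any, x: float) -> Any:
        # Reuse the cached output if x is close to the last *computed* input.
        if self.last_input is not None and abs(x - self.last_input) < self.threshold:
            return self.last_output
        self.last_input = x
        self.last_output = module._omni_original_forward(x)
        self.computed += 1
        return self.last_output

    def reset_state(self, module: Any) -> Any:
        self.last_input = None
        self.last_output = None
        self.computed = 0
        return module


module = SquareModule()
hook = CachingForwardHook(threshold=0.01)
print(hook.new_forward(module, 2.0))    # 4.0, computed
print(hook.new_forward(module, 2.001))  # 4.0, reused from cache
print(hook.new_forward(module, 3.0))    # 9.0, computed
print(hook.computed)                    # 2
```

Because a cache hit does not update `last_input`, the distance is always measured against the last input that was actually computed, which bounds how far the reused output can drift.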

Parameters:

  • module (Module, required): The module being called.

  • *args (Any, default ()): Positional arguments to forward.

  • **kwargs (Any, default {}): Keyword arguments to forward.

Returns:

  • Any: The output of the replacement for the forward pass.

post_forward

post_forward(module: Module, output: Any) -> Any

Called after the module's forward pass.

Parameters:

  • module (Module, required): The module that was called.

  • output (Any, required): The output from the forward method.

Returns:

  • Any: The (possibly modified) output.

pre_forward

pre_forward(
    module: Module, *args: Any, **kwargs: Any
) -> tuple[tuple, dict]

Called before the module's forward pass.

Parameters:

  • module (Module, required): The module being called.

  • *args (Any, default ()): Positional arguments to forward.

  • **kwargs (Any, default {}): Keyword arguments to forward.

Returns:

  • tuple[tuple, dict]: Tuple of (args, kwargs) to pass to the forward method.

reset_state

reset_state(module: Module) -> Module

Reset any state associated with this hook.

Parameters:

  • module (Module, required): The module this hook is attached to.

Returns:

  • Module: The module.

StateManager

Manage per-context hook state instances.

get_state

get_state() -> BaseState

reset

reset() -> None

set_context

set_context(name: str) -> None
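A sketch of per-context state, assuming get_state() returns (and lazily creates) the state for the context selected by set_context(). `CounterState`, the constructor argument, and the context names "cond"/"uncond" are hypothetical. This reset() resets each state in place; whether the real class instead discards them is not specified here.

```python
from __future__ import annotations


class BaseState:
    """Base class for hook state containers."""

    def reset(self) -> None:
        pass


class CounterState(BaseState):
    """Hypothetical state: counts calls seen in one context."""

    def __init__(self) -> None:
        self.steps = 0

    def reset(self) -> None:
        self.steps = 0


class StateManager:
    """Sketch: one lazily created state instance per context name."""

    def __init__(self, state_cls: type[BaseState]) -> None:
        self._state_cls = state_cls
        self._states: dict[str, BaseState] = {}
        self._context: str | None = None

    def set_context(self, name: str) -> None:
        self._context = name

    def get_state(self) -> BaseState:
        if self._context is None:
            raise RuntimeError("call set_context() before get_state()")
        if self._context not in self._states:
            self._states[self._context] = self._state_cls()
        return self._states[self._context]

    def reset(self) -> None:
        # Assumption: reset each state in place rather than discarding it.
        for state in self._states.values():
            state.reset()


manager = StateManager(CounterState)
manager.set_context("cond")
manager.get_state().steps += 1
manager.set_context("uncond")
manager.get_state().steps += 1
manager.get_state().steps += 1
manager.set_context("cond")
print(manager.get_state().steps)  # 1
```

Keying state by context lets one hook instance serve several interleaved forward passes (for example, two branches of a batch) without their cached values overwriting each other.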

sort_hooks_after_call

sort_hooks_after_call(func)

Decorator that calls the wrapped method on the hook registry, then re-sorts the hooks.

It should be applied to registry methods that mutate the hook set by adding or removing hooks.