vllm_omni.diffusion.hooks.base
Base hook classes for model forward interception.
This module provides the foundational hook mechanism that allows intercepting and modifying model forward passes without invasive changes to model code.
HookRegistry
Registry of hooks attached to a module.
Manages multiple hooks that can intercept a module's forward pass. For determinism, hooks are called in order sorted by name.
check_if_exists_or_initialize classmethod
check_if_exists_or_initialize(module: Module) -> HookRegistry
Get existing registry or create a new one for the module.
This method ensures that a HookRegistry exists on the module and returns it. If no registry exists yet, one is created and attached to the module. It is equivalent to get_or_create() and is provided for compatibility with the diffusers API.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module to get/create a registry for. | required |
Returns:
| Type | Description |
|---|---|
| HookRegistry | The HookRegistry for this module. |
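The get-or-create behavior described above can be sketched in plain Python. This is a minimal illustration, not the vllm_omni implementation: it assumes the registry is stored on the module under a hypothetical attribute name `_omni_hook_registry`, and it uses a plain object in place of a torch `Module`.

```python
class HookRegistry:
    """Minimal stand-in for illustration; not the real vllm_omni class."""

    # Hypothetical attribute name used to attach the registry to a module.
    _ATTR = "_omni_hook_registry"

    def __init__(self, module):
        self.module = module
        self.hooks = {}

    @classmethod
    def get_or_create(cls, module):
        # Reuse the registry already attached to the module, if any.
        registry = getattr(module, cls._ATTR, None)
        if registry is None:
            # Otherwise create one and attach it so later calls find it.
            registry = cls(module)
            setattr(module, cls._ATTR, registry)
        return registry

    @classmethod
    def check_if_exists_or_initialize(cls, module):
        # Same behavior, kept under the diffusers-style name.
        return cls.get_or_create(module)


class ToyModule:
    def forward(self, x):
        return x + 1


m = ToyModule()
r1 = HookRegistry.check_if_exists_or_initialize(m)
r2 = HookRegistry.get_or_create(m)
print(r1 is r2)  # True: the second call returns the attached registry
```

Storing the registry on the module itself means any code holding a reference to the module can find its hooks without a separate global lookup table.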
dispatch
Dispatch a forward call through registered hooks.
Multiple hooks may be registered, with the caveat that at most one hook may override new_forward. Pre- and post-processing on hooks are assumed to be composable; for determinism, execution proceeds as follows:
- Run pre_forward on all hooks in sorted order. Hooks are sorted alphabetically by name, except that the hook overriding forward (self._new_fwd_impl_hook), if it exists, is placed last.
- If self._new_fwd_impl_hook is not None, call its new_forward. Otherwise, call the module's original forward.
- Run post_forward on all hooks in reverse sorted order.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Any | Positional arguments to forward. | () |
| **kwargs | Any | Keyword arguments to forward. | {} |
Returns:
| Type | Description |
|---|---|
| Any | The output of the forward pass. |
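The three-step dispatch order can be sketched with plain Python stand-ins. This is an illustration of the documented ordering only, not the library code: the `dispatch` function, the hook classes, and the `post_forward(module, output)` signature are assumptions made for this sketch, and the hooks are passed explicitly as a name-to-hook dict instead of living on a registry.

```python
class TraceHook:
    """Hypothetical hook that records when its pre/post steps run."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def pre_forward(self, module, *args, **kwargs):
        self.log.append(f"pre:{self.name}")
        return args, kwargs

    def post_forward(self, module, output):
        self.log.append(f"post:{self.name}")
        return output


class ScaleOutputHook(TraceHook):
    """Hypothetical hook that also overrides new_forward."""

    def new_forward(self, module, *args, **kwargs):
        self.log.append(f"new_forward:{self.name}")
        return module.forward(*args, **kwargs) * 10


def dispatch(module, hooks, new_fwd_name, *args, **kwargs):
    """Sketch of the documented order. hooks maps name -> hook;
    new_fwd_name names the hook that overrides new_forward, or is None."""
    # Sort alphabetically, moving the forward-overriding hook (if any) last.
    names = sorted(n for n in hooks if n != new_fwd_name)
    if new_fwd_name is not None:
        names.append(new_fwd_name)
    ordered = [hooks[n] for n in names]

    # Step 1: pre_forward on every hook, in sorted order.
    for hook in ordered:
        args, kwargs = hook.pre_forward(module, *args, **kwargs)

    # Step 2: the overriding hook's new_forward, else the original forward.
    if new_fwd_name is not None:
        output = hooks[new_fwd_name].new_forward(module, *args, **kwargs)
    else:
        output = module.forward(*args, **kwargs)

    # Step 3: post_forward on every hook, in reverse sorted order.
    for hook in reversed(ordered):
        output = hook.post_forward(module, output)
    return output


class ToyModule:
    def forward(self, x):
        return x + 1


log = []
hooks = {
    "b_trace": TraceHook("b_trace", log),
    "a_scale": ScaleOutputHook("a_scale", log),
    "c_trace": TraceHook("c_trace", log),
}
out = dispatch(ToyModule(), hooks, "a_scale", 2)
print(out)  # 30
print(log)
```

Although "a_scale" sorts first alphabetically, it overrides new_forward, so its pre_forward runs last and its post_forward runs first, which keeps it closest to the forward call it replaces.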
get_hook
get_or_create classmethod
get_or_create(module: Module) -> HookRegistry
Get existing registry or create a new one for the module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module to get/create a registry for. | required |
Returns:
| Type | Description |
|---|---|
| HookRegistry | The HookRegistry for this module. |
register_hook
remove_hook
remove_hook(name: str) -> None
Remove a hook by name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | The name of the hook to remove. | required |
reset
Reset all hooks and clear the registry.
This removes all hooks from the registry, resets each hook's state, and restores module.forward to its original implementation.
reset_hook
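One way a registry could swap in its own forward and later restore the original is sketched below. This is an assumption-laden illustration, not the vllm_omni code: `MiniRegistry` and `DoubleInputHook` are hypothetical, the patching strategy (saving the bound method as `_omni_original_forward` and shadowing `forward` with an instance attribute) is a guess, and the dispatch here runs pre_forward only for brevity.

```python
class ToyModule:
    def forward(self, x):
        return x + 1


class DoubleInputHook:
    """Hypothetical hook: doubles positional inputs and counts calls."""

    def __init__(self):
        self.calls = 0

    def pre_forward(self, module, *args, **kwargs):
        self.calls += 1
        return tuple(a * 2 for a in args), kwargs

    def reset_state(self, module):
        self.calls = 0
        return module


class MiniRegistry:
    """Hypothetical stand-in showing how forward could be patched and restored."""

    def __init__(self, module):
        self.module = module
        self.hooks = {}
        # Keep the original bound forward so it can be called and restored later.
        module._omni_original_forward = module.forward
        # An instance attribute shadows the class's forward method.
        module.forward = self.dispatch

    def register_hook(self, name, hook):
        self.hooks[name] = hook

    def dispatch(self, *args, **kwargs):
        # pre_forward only, for brevity; see dispatch for the full order.
        for name in sorted(self.hooks):
            args, kwargs = self.hooks[name].pre_forward(self.module, *args, **kwargs)
        return self.module._omni_original_forward(*args, **kwargs)

    def reset(self):
        # Reset each hook's state, then drop all hooks.
        for hook in self.hooks.values():
            hook.reset_state(self.module)
        self.hooks.clear()
        # Restore the module's original forward implementation.
        self.module.forward = self.module._omni_original_forward


m = ToyModule()
registry = MiniRegistry(m)
hook = DoubleInputHook()
registry.register_hook("double", hook)
print(m.forward(3))  # 7: input doubled to 6, then 6 + 1
registry.reset()
print(m.forward(3))  # 4: original forward again
```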
reset_hook(name: str) -> None
Reset a hook's state by name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | The name of the hook to reset. | required |
update_sorted_hooks
Sort hooks by name, which dictates pre/post process order.
ModelHook
Base class for model hooks that can override a module's forward.
Hooks can intercept the forward pass at two points:
- pre_forward: called before the original forward; can modify args/kwargs.
- post_forward: called after the original forward; can modify the output.
Subclasses can override either or both methods. The default implementations pass through args/kwargs/output unchanged.
For more complex behavior, override new_forward to completely replace the forward logic.
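A subclass that uses both interception points might look like the sketch below. The `ModelHook` here is a local stand-in that mirrors only the documented pass-through defaults; `ScaleHook`, `run_with_hook`, and the `post_forward(module, output)` signature are assumptions for illustration, and `run_with_hook` covers the single-hook case only.

```python
class ModelHook:
    """Stand-in mirroring the documented pass-through defaults."""

    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs

    def post_forward(self, module, output):
        return output


class ScaleHook(ModelHook):
    """Hypothetical hook: scales the first input and rounds the output."""

    def __init__(self, scale):
        self.scale = scale

    def pre_forward(self, module, *args, **kwargs):
        # Modify the positional arguments before the original forward runs.
        x, *rest = args
        return (x * self.scale, *rest), kwargs

    def post_forward(self, module, output):
        # Modify the output after the original forward returns.
        return round(output, 2)


def run_with_hook(module, hook, *args, **kwargs):
    # Simplified single-hook flow: pre_forward -> forward -> post_forward.
    args, kwargs = hook.pre_forward(module, *args, **kwargs)
    output = module.forward(*args, **kwargs)
    return hook.post_forward(module, output)


class ToyModule:
    def forward(self, x, bias=0.0):
        return x / 3 + bias


print(run_with_hook(ToyModule(), ScaleHook(2.0), 1.0, bias=0.5))  # 1.17
print(run_with_hook(ToyModule(), ModelHook(), 3.0))  # 1.0, unchanged
```

Keyword arguments such as `bias` pass through pre_forward untouched unless the hook chooses to modify them.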
initialize_hook
Initialize the hook when it's registered to a module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module this hook is being attached to. | required |
Returns:
| Type | Description |
|---|---|
| Module | The module (possibly modified). |
new_forward
Replacement for the module's forward pass. Override this for more complex cases, e.g., TeaCache. If a subclass overrides this method, it is called instead of self.module._omni_original_forward when the hooks execute.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module being called. | required |
| *args | Any | Positional arguments to forward. | () |
| **kwargs | Any | Keyword arguments to forward. | {} |
Returns:
| Type | Description |
|---|---|
| Any | The output of the replacement for the forward pass. |
post_forward
Called after the module's forward pass; can modify the output.
pre_forward
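A hook that replaces forward entirely can be sketched as a simple memoizing cache. This is a hypothetical example, much simpler than TeaCache: `MemoizeHook` is invented for illustration, the local `ModelHook` is a minimal stand-in, and `_omni_original_forward` is attached by hand to the toy module rather than by a registry.

```python
class ModelHook:
    """Minimal stand-in with pass-through defaults."""

    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs

    def post_forward(self, module, output):
        return output


class MemoizeHook(ModelHook):
    """Hypothetical hook whose new_forward reuses outputs for repeated inputs."""

    def __init__(self):
        self.cache = {}
        self.misses = 0

    def new_forward(self, module, *args, **kwargs):
        # Key on both positional and keyword arguments (must be hashable).
        key = (args, tuple(sorted(kwargs.items())))
        if key not in self.cache:
            self.misses += 1
            # On a miss, fall back to the module's original forward.
            self.cache[key] = module._omni_original_forward(*args, **kwargs)
        return self.cache[key]


class ToyModule:
    def __init__(self):
        # Attached by hand for this sketch.
        self._omni_original_forward = self.forward

    def forward(self, x, offset=0):
        return x * x + offset


m = ToyModule()
hook = MemoizeHook()
print(hook.new_forward(m, 4))  # 16
print(hook.new_forward(m, 4))  # 16, served from the cache
print(hook.misses)  # 1
```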
Called before the module's forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module being called. | required |
| *args | Any | Positional arguments to forward. | () |
| **kwargs | Any | Keyword arguments to forward. | {} |
Returns:
| Type | Description |
|---|---|
| tuple[tuple, dict] | Tuple of (args, kwargs) to pass to the forward method. |
reset_state
Reset any state associated with this hook.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | The module this hook is attached to. | required |
Returns:
| Type | Description |
|---|---|
| Module | The module. |
StateManager
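A stateful hook and its reset can be sketched as follows. `StepCounterHook` and the standalone `reset_hook` function are hypothetical stand-ins; the sketch only assumes, as documented above, that reset_state takes the module and returns it, and that reset_hook looks a hook up by name and returns None.

```python
class StepCounterHook:
    """Hypothetical stateful hook that counts forward calls."""

    def __init__(self):
        self.step = 0

    def pre_forward(self, module, *args, **kwargs):
        self.step += 1
        return args, kwargs

    def reset_state(self, module):
        # Clear per-run state and return the module, as documented.
        self.step = 0
        return module


def reset_hook(hooks, name, module):
    """Sketch of reset_hook: look up a hook by name and reset its state."""
    hooks[name].reset_state(module)


module = object()  # stand-in for a torch Module
hooks = {"counter": StepCounterHook()}
for _ in range(3):
    hooks["counter"].pre_forward(module, 1.0)
print(hooks["counter"].step)  # 3
reset_hook(hooks, "counter", module)
print(hooks["counter"].step)  # 0
```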
sort_hooks_after_call
Calls the method on the hook registry, then sorts the hooks.
This should be applied to methods that add or remove hooks.
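Applied as a decorator, this pattern could look like the sketch below. It is an illustration under assumptions: the decorator body, the `MiniRegistry` class, its `sorted_names` attribute, and the `register_hook(name, hook)` signature are hypothetical and are not taken from the vllm_omni source.

```python
import functools


def sort_hooks_after_call(method):
    """Sketch: call the wrapped registry method, then re-sort the hooks."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        self.update_sorted_hooks()
        return result

    return wrapper


class MiniRegistry:
    """Hypothetical stand-in registry; not the vllm_omni class."""

    def __init__(self):
        self.hooks = {}
        self.sorted_names = []

    def update_sorted_hooks(self):
        self.sorted_names = sorted(self.hooks)

    @sort_hooks_after_call
    def register_hook(self, name, hook):
        self.hooks[name] = hook

    @sort_hooks_after_call
    def remove_hook(self, name):
        del self.hooks[name]


registry = MiniRegistry()
registry.register_hook("b_hook", object())
registry.register_hook("a_hook", object())
print(registry.sorted_names)  # ['a_hook', 'b_hook']
```

Centralizing the re-sort in a decorator means no mutating method can forget to refresh the order that dispatch relies on.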