vllm_omni.diffusion.hooks
Hook mechanism for model forward interception.
Modules:
| Name | Description |
|---|---|
base | Base hook classes for model forward interception. |
sequence_parallel | Sequence Parallelism hooks for non-intrusive SP support. |
HookRegistry
Registry of hooks attached to a module.
Manages multiple hooks that can intercept a module's forward pass. Hooks are called in sorted order by name for determinism.
check_if_exists_or_initialize classmethod
check_if_exists_or_initialize(
module: Module,
) -> HookRegistry
Get existing registry or create a new one for the module.
This method ensures a HookRegistry exists on the module and returns it. If no registry exists yet, it creates one and attaches it to the module. It is equivalent to get_or_create() and is provided for compatibility with the diffusers API.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The module to get/create a registry for. | required |
Returns:
| Type | Description |
|---|---|
HookRegistry | The HookRegistry for this module. |
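The get-or-create behavior can be sketched as follows. This is a minimal illustration, not vLLM-Omni's implementation: Module stands in for torch.nn.Module, and the attribute name _hook_registry is a hypothetical placeholder for wherever the real registry is stored.

```python
class Module:
    """Stand-in for torch.nn.Module; only attribute storage matters here."""


class HookRegistry:
    # Hypothetical attribute name; the attribute vLLM-Omni actually uses may differ.
    _ATTR = "_hook_registry"

    def __init__(self, module):
        self.module = module
        self.hooks = {}

    @classmethod
    def check_if_exists_or_initialize(cls, module):
        registry = getattr(module, cls._ATTR, None)
        if registry is None:
            # First call for this module: create a registry and attach it.
            registry = cls(module)
            setattr(module, cls._ATTR, registry)
        return registry

    # Same behavior under the other documented name.
    get_or_create = check_if_exists_or_initialize


model = Module()
first = HookRegistry.check_if_exists_or_initialize(model)
second = HookRegistry.get_or_create(model)
print(first is second)  # True: the registry is created once per module
```

Storing the registry on the module itself means any code holding the module can find its hooks without a global lookup table.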
dispatch
Dispatch a forward call through registered hooks.
Multiple hooks may be used, with the caveat that only one hook may override new_forward. Pre/post processing on hooks is assumed to be composable; for determinism, execution proceeds as follows:
1. Run pre_forward on all hooks in sorted order. Hooks are sorted alphabetically, except that the hook overriding forward (self._new_fwd_impl_hook), if it exists, runs last.
2. If self._new_fwd_impl_hook is not None, call its new_forward; otherwise call the original module forward.
3. Run post_forward on all hooks in reverse sorted order.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
*args | Any | Positional arguments to forward. | () |
**kwargs | Any | Keyword arguments to forward. | {} |
Returns:
| Type | Description |
|---|---|
Any | The output of the forward pass. |
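The three-step order above can be sketched with plain Python objects. This is an illustrative model of the documented ordering only; RecordingHook, Doubler, the overrides_forward flag, and original_forward are hypothetical names, not vLLM-Omni API.

```python
class RecordingHook:
    """Hypothetical hook that logs each stage it takes part in."""

    def __init__(self, name, log, overrides_forward=False):
        self.name = name
        self.log = log
        self.overrides_forward = overrides_forward

    def pre_forward(self, module, *args, **kwargs):
        self.log.append(f"pre:{self.name}")
        return args, kwargs

    def new_forward(self, module, *args, **kwargs):
        self.log.append(f"new_forward:{self.name}")
        return module.original_forward(*args, **kwargs)

    def post_forward(self, module, output):
        self.log.append(f"post:{self.name}")
        return output


class Doubler:
    """Stand-in module whose original forward doubles its input."""

    def original_forward(self, x):
        return 2 * x


def dispatch(module, hooks, *args, **kwargs):
    # At most one hook may replace the forward; it is moved to the end of the order.
    fwd = [name for name, hook in hooks.items() if hook.overrides_forward]
    if len(fwd) > 1:
        raise ValueError("only one hook may override new_forward")
    order = sorted(name for name in hooks if name not in fwd) + fwd

    for name in order:  # 1. pre_forward in sorted order
        args, kwargs = hooks[name].pre_forward(module, *args, **kwargs)
    if fwd:  # 2. replacement forward if present, else the original forward
        output = hooks[fwd[0]].new_forward(module, *args, **kwargs)
    else:
        output = module.original_forward(*args, **kwargs)
    for name in reversed(order):  # 3. post_forward in reverse order
        output = hooks[name].post_forward(module, output)
    return output


log = []
hooks = {
    "zeta": RecordingHook("zeta", log),
    "alpha": RecordingHook("alpha", log),
    "cache": RecordingHook("cache", log, overrides_forward=True),
}
print(dispatch(Doubler(), hooks, 21))  # 42
print(log)
```

Here the log reads pre:alpha, pre:zeta, pre:cache, new_forward:cache, post:cache, post:zeta, post:alpha: the forward-overriding hook "cache" runs last in the pre phase and first in the post phase, even though it sorts before "zeta".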
get_hook
get_or_create classmethod
get_or_create(module: Module) -> HookRegistry
Get existing registry or create a new one for the module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The module to get/create a registry for. | required |
Returns:
| Type | Description |
|---|---|
HookRegistry | The HookRegistry for this module. |
register_hook
remove_hook
remove_hook(name: str) -> None
Remove a hook by name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | str | The name of the hook to remove. | required |
reset
Reset all hooks and clear the registry.
This removes all hooks from the registry and resets each hook's state. Also restores module.forward to its original implementation.
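A minimal sketch of what reset involves, under stated assumptions: the registry stashes the module's unpatched forward in _omni_original_forward (the attribute named in the new_forward description below) and patches module.forward on the instance. The dispatch here is deliberately simplified to post_forward only, and AddTenHook is hypothetical.

```python
class Module:
    """Stand-in module with a trivial forward."""

    def forward(self, x):
        return x + 1


class AddTenHook:
    """Hypothetical stateful hook: adds 10 to the output and counts calls."""

    def __init__(self):
        self.calls = 0

    def post_forward(self, module, output):
        self.calls += 1
        return output + 10

    def reset_state(self, module):
        self.calls = 0
        return module


class HookRegistry:
    def __init__(self, module):
        self.module = module
        self.hooks = {}
        # Stash the unpatched bound method, then route module.forward through dispatch.
        module._omni_original_forward = module.forward
        module.forward = self.dispatch

    def register_hook(self, name, hook):
        self.hooks[name] = hook

    def dispatch(self, *args, **kwargs):
        output = self.module._omni_original_forward(*args, **kwargs)
        for name in reversed(sorted(self.hooks)):
            output = self.hooks[name].post_forward(self.module, output)
        return output

    def reset(self):
        # Reset every hook's state, drop the hooks, and restore the original forward.
        for hook in self.hooks.values():
            hook.reset_state(self.module)
        self.hooks.clear()
        self.module.forward = self.module._omni_original_forward


model = Module()
registry = HookRegistry(model)
hook = AddTenHook()
registry.register_hook("add_ten", hook)
print(model.forward(1))  # 12
registry.reset()
print(model.forward(1))  # 2
```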
reset_hook
reset_hook(name: str) -> None
Reset a hook's state by name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | str | The name of the hook to reset. | required |
update_sorted_hooks
Sort hooks by name, which dictates pre/post process order.
ModelHook
Base class for model hooks that can override a module's forward.
Hooks can intercept the forward pass at two points:
- pre_forward: called before the original forward; can modify args/kwargs.
- post_forward: called after the original forward; can modify the output.
Subclasses can override either or both methods. The default implementations pass through args/kwargs/output unchanged.
For more complex behavior, override new_forward to completely replace the forward logic.
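A small subclass illustrates the pre/post contract. ModelHook below is a stand-in that mirrors the documented pass-through defaults rather than the real import, ScaleAndRoundHook and run_with_hook are hypothetical, and the post_forward(module, output) signature is an assumption because the reference does not list its parameters.

```python
class ModelHook:
    """Stand-in base class mirroring the documented pass-through defaults."""

    def initialize_hook(self, module):
        return module

    def pre_forward(self, module, *args, **kwargs):
        return args, kwargs

    def post_forward(self, module, output):
        return output


class ScaleAndRoundHook(ModelHook):
    """Hypothetical hook: scales the first positional arg, rounds the output."""

    def __init__(self, factor):
        self.factor = factor

    def pre_forward(self, module, *args, **kwargs):
        x, *rest = args
        return (x * self.factor, *rest), kwargs

    def post_forward(self, module, output):
        return round(output, 2)


def run_with_hook(hook, module, forward, *args, **kwargs):
    # Single-hook flow: pre_forward -> original forward -> post_forward.
    args, kwargs = hook.pre_forward(module, *args, **kwargs)
    return hook.post_forward(module, forward(*args, **kwargs))


def forward(x, bias=0.0):
    return x / 7 + bias


print(run_with_hook(ModelHook(), None, forward, 7))  # 1.0: default hook changes nothing
print(run_with_hook(ScaleAndRoundHook(3), None, forward, 1, bias=1.0))  # 1.43
```

pre_forward must return an (args, kwargs) pair even when it only touches one argument, so untouched arguments such as bias flow through unchanged.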
initialize_hook
Initialize the hook when it's registered to a module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The module this hook is being attached to. | required |
Returns:
| Type | Description |
|---|---|
Module | The module (possibly modified). |
new_forward
Replace the module's forward pass. Override this for cases that need more than pre/post processing, e.g., TeaCache. If a subclass overrides this method, it is called instead of self.module._omni_original_forward when the hooks execute.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The module being called. | required |
*args | Any | Positional arguments to forward. | () |
**kwargs | Any | Keyword arguments to forward. | {} |
Returns:
| Type | Description |
|---|---|
Any | The output of the replacement for the forward pass. |
post_forward
pre_forward
Called before the module's forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The module being called. | required |
*args | Any | Positional arguments to forward. | () |
**kwargs | Any | Keyword arguments to forward. | {} |
Returns:
| Type | Description |
|---|---|
tuple[tuple, dict] | Tuple of (args, kwargs) to pass to the forward method. |
reset_state
Reset any state associated with this hook.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The module this hook is attached to. | required |
Returns:
| Type | Description |
|---|---|
Module | The module. |
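new_forward and reset_state typically go together in stateful hooks such as caches. The sketch below is loosely in the spirit of TeaCache but is not its algorithm: it reuses the last computed output when the input's relative change is under a threshold. Square, ReuseOutputHook, and the threshold are hypothetical, and the module stand-in exposes _omni_original_forward directly.

```python
class Square:
    """Stand-in module; counts how often the expensive forward really runs."""

    def __init__(self):
        self.calls = 0

    def _omni_original_forward(self, x):
        self.calls += 1
        return x * x


class ReuseOutputHook:
    """Hypothetical hook that replaces forward and reuses a cached output.

    Not TeaCache's actual criterion: it simply compares the new input with the
    last input that was really computed.
    """

    def __init__(self, threshold):
        self.threshold = threshold
        self.reset_state(None)

    def new_forward(self, module, x):
        if self.last_input is not None:
            rel_change = abs(x - self.last_input) / max(abs(self.last_input), 1e-8)
            if rel_change < self.threshold:
                self.skipped += 1
                return self.last_output  # skip the original forward entirely
        output = module._omni_original_forward(x)
        self.last_input, self.last_output = x, output
        return output

    def reset_state(self, module):
        # Clear the cache so a new generation does not reuse stale outputs.
        self.last_input = None
        self.last_output = None
        self.skipped = 0
        return module


model, hook = Square(), ReuseOutputHook(threshold=0.05)
print([hook.new_forward(model, x) for x in (10.0, 10.2, 12.0)])  # [100.0, 100.0, 144.0]
print(model.calls, hook.skipped)  # 2 1
hook.reset_state(model)
```

Without reset_state, the next request would start by comparing against the previous request's last input and could return its cached output.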
SequenceParallelGatherHook
Bases: ModelHook
Hook for gathering outputs after a module's forward pass.
This hook is registered to modules that need their outputs gathered from all sequence parallel ranks. It intercepts the output and gathers it according to the plan specification.
Note: This corresponds to ContextParallelGatherHook in diffusers.
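Gathering amounts to concatenating the per-rank shards, in rank order, along the plan's gather_dim. This pure-Python sketch uses nested lists in place of tensors, and concat and gather_output are hypothetical helpers; in a real run the shards would arrive through a collective such as an all-gather rather than being passed in.

```python
def concat(tensors, dim):
    """Concatenate equally shaped nested lists along `dim`, like torch.cat."""
    if dim == 0:
        return [item for t in tensors for item in t]
    # Recurse: concatenate matching entries of the leading dimension.
    return [concat([t[i] for t in tensors], dim - 1) for i in range(len(tensors[0]))]


def gather_output(shards_by_rank, gather_dim):
    """Rebuild the full output from every rank's shard, ordered by rank.

    In a real run each rank obtains `shards_by_rank` through a collective
    (e.g. an all-gather); here the list is simply passed in.
    """
    return concat(shards_by_rank, gather_dim)


# Two SP ranks, each holding [batch=1][seq=2][hidden=1] of the proj_out output.
rank0 = [[[1.0], [2.0]]]
rank1 = [[[3.0], [4.0]]]
print(gather_output([rank0, rank1], gather_dim=1))  # [[[1.0], [2.0], [3.0], [4.0]]]
```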
SequenceParallelSplitHook
Bases: ModelHook
Hook for splitting inputs before a module's forward pass.
This hook is registered to modules that need their inputs sharded across sequence parallel ranks. It intercepts the forward call, shards specified inputs according to the plan, and passes the sharded inputs to the original forward.
For split_output=True inputs, it shards the output instead.
Supports both SequenceParallelInput (full split) and SequenceParallelPartialInput (partial split for text/image separation).
Note: This corresponds to ContextParallelSplitHook in diffusers.
module_forward_metadata instance-attribute
module_forward_metadata: ModuleForwardMetadata | None = None
post_forward
Shard outputs for split_output=True entries.
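The full split performed by SequenceParallelSplitHook can be pictured as each rank keeping one contiguous, equal-size chunk along split_dim, as torch.chunk would produce. That chunking scheme is an assumption of this sketch, as is raising on uneven lengths (the real hook may pad instead); shard is a hypothetical helper operating on nested lists. The same function would apply to a module's output for split_output=True entries.

```python
def shard(tensor, dim, world_size, rank):
    """Return `rank`'s contiguous chunk of a nested list along `dim`."""
    if dim == 0:
        n = len(tensor)
        if n % world_size:
            # Assumption: require even division; the real hook may pad instead.
            raise ValueError(f"size {n} is not divisible by SP world size {world_size}")
        chunk = n // world_size
        return tensor[rank * chunk:(rank + 1) * chunk]
    return [shard(sub, dim - 1, world_size, rank) for sub in tensor]


# hidden_states with shape [batch=1][seq=4][hidden=1], split_dim=1 over 2 ranks.
hidden_states = [[[0.0], [1.0], [2.0], [3.0]]]
for rank in range(2):
    print(rank, shard(hidden_states, dim=1, world_size=2, rank=rank))
# 0 [[[0.0], [1.0]]]
# 1 [[[2.0], [3.0]]]
```

Concatenating the rank-ordered chunks along the same dimension recovers the original sequence, which is what the gather hook relies on.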
StateManager
apply_sequence_parallel
apply_sequence_parallel(
module: Module,
config: SequenceParallelConfig,
plan: SequenceParallelModelPlan,
) -> None
Apply sequence parallel hooks to a model according to the plan.
This function registers hooks on the specified submodules to automatically shard inputs and gather outputs for sequence parallelism.
Note: This corresponds to apply_context_parallel in diffusers.
The complete SP flow is:
1. Input sharding (SequenceParallelSplitHook): split the sequence across SP ranks.
2. Attention parallelism (handled by vLLM-Omni's Attention layer):
   - Ulysses: All-to-All over Q/K/V heads
   - Ring: K/V circulation in a ring topology
   - Hybrid: both (Ulysses handles head redistribution, Ring handles K/V)
3. Output gathering (SequenceParallelGatherHook): gather the sequence from SP ranks.
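The Ulysses step in the flow above exchanges a sequence shard for a head shard: before the All-to-All each rank holds S/U tokens with all H heads, and afterwards it holds all S tokens for H/U heads, so attention over the full sequence can run locally per head. The pure-Python simulation below illustrates only that data movement; ulysses_all_to_all is a hypothetical helper, and the real layer performs the exchange with collective communication on tensors.

```python
def ulysses_all_to_all(local_shards):
    """Simulate Ulysses' All-to-All: per-rank [S/U][H] -> per-rank [S][H/U].

    local_shards[r][s][h] is head h of the s-th local token on rank r,
    where rank r holds a contiguous chunk of the sequence.
    """
    world = len(local_shards)
    num_heads = len(local_shards[0][0])
    if num_heads % world:
        raise ValueError("head count must be divisible by the Ulysses degree")
    per_rank = num_heads // world
    result = []
    for dst in range(world):
        # dst receives its head slice of every token, taking sources in rank
        # order, which restores the full sequence order.
        result.append([
            token[dst * per_rank:(dst + 1) * per_rank]
            for src in range(world)
            for token in local_shards[src]
        ])
    return result


# Sequence of 4 tokens, 2 heads, Ulysses degree 2; entries are labeled "t{token}h{head}".
full = [[f"t{s}h{h}" for h in range(2)] for s in range(4)]
after = ulysses_all_to_all([full[0:2], full[2:4]])
print(after[0])  # [['t0h0'], ['t1h0'], ['t2h0'], ['t3h0']]
print(after[1])  # [['t0h1'], ['t1h1'], ['t2h1'], ['t3h1']]
```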
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The model to apply SP to. | required |
config | SequenceParallelConfig | The sequence parallel configuration. | required |
plan | SequenceParallelModelPlan | Dictionary mapping module names to input/output specifications. | required |
Example
```python
config = SequenceParallelConfig(ulysses_degree=2)
plan = {
    "": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
    "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
apply_sequence_parallel(model, config, plan)
```
Note
vLLM-Omni's Attention layer automatically handles the internal parallelism (Ulysses All-to-All or Ring attention) based on the forward_context configuration. This function only handles input/output sharding for the model as a whole.
disable_sequence_parallel_for_model
Disable sequence parallelism for a model.
Note: This corresponds to disable_context_parallel_for_model in diffusers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | The model to disable SP for. | required |
enable_sequence_parallel_for_model
enable_sequence_parallel_for_model(
model: Module,
config: SequenceParallelConfig | None = None,
) -> None
Enable sequence parallelism for a model using its _sp_plan.
This is a convenience function that reads the model's _sp_plan attribute and applies sequence parallelism automatically.
Note: This corresponds to enable_context_parallel_for_model in diffusers, but uses vLLM-Omni's _sp_plan instead of diffusers' _cp_plan.
The function performs two main tasks:
1. Applies the _sp_plan hooks to shard inputs and gather outputs.
2. Ensures Attention layers are configured for the correct parallel mode (handled automatically by vLLM-Omni's forward_context mechanism).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | The model to enable SP for. Must have a _sp_plan attribute. | required |
config | SequenceParallelConfig \| None | Optional config. If None, uses a default based on the current parallel state. | None |
Raises:
| Type | Description |
|---|---|
ValueError | If model has no _sp_plan defined. |
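The _sp_plan lookup and its documented failure mode can be sketched as below. ToyTransformer, PlainModel, and read_sp_plan are hypothetical; plain dicts stand in for SequenceParallelInput and SequenceParallelOutput, and declaring the plan as a class attribute is an assumption modeled on diffusers' _cp_plan.

```python
class ToyTransformer:
    # Hypothetical plan: plain dicts stand in for SequenceParallelInput/Output.
    _sp_plan = {
        "": {"hidden_states": {"split_dim": 1, "expected_dims": 3}},
        "proj_out": {"gather_dim": 1, "expected_dims": 3},
    }


class PlainModel:
    pass


def read_sp_plan(model):
    """Hypothetical helper: fetch _sp_plan or raise the documented ValueError."""
    plan = getattr(model, "_sp_plan", None)
    if plan is None:
        raise ValueError(f"{type(model).__name__} has no _sp_plan defined")
    return plan


print(sorted(read_sp_plan(ToyTransformer())))  # ['', 'proj_out']
try:
    read_sp_plan(PlainModel())
except ValueError as err:
    print(err)  # PlainModel has no _sp_plan defined
```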
Note
vLLM-Omni supports Ulysses + Ring hybrid parallelism:
- ulysses_degree > 1: uses All-to-All communication over Q/K/V heads
- ring_degree > 1: uses Ring attention with K/V passing
- Both > 1: hybrid mode (Ulysses handles head redistribution, Ring handles K/V circulation)
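The three cases in the note above map to a simple decision on the two degrees. In this sketch, sp_mode is a hypothetical helper, and treating the total SP group size as ulysses_degree * ring_degree is an assumption (the usual arrangement for Ulysses + Ring hybrids), not something stated in this reference.

```python
def sp_mode(ulysses_degree: int, ring_degree: int) -> tuple[str, int]:
    """Classify the SP mode from the two degrees (hypothetical helper)."""
    if ulysses_degree < 1 or ring_degree < 1:
        raise ValueError("degrees must be >= 1")
    # Assumption: the SP group size is the product of the two degrees.
    sp_size = ulysses_degree * ring_degree
    if ulysses_degree > 1 and ring_degree > 1:
        mode = "hybrid"   # Ulysses redistributes heads, Ring circulates K/V
    elif ulysses_degree > 1:
        mode = "ulysses"  # All-to-All over Q/K/V heads
    elif ring_degree > 1:
        mode = "ring"     # K/V passed around the ring
    else:
        mode = "disabled"
    return mode, sp_size


for degrees in [(2, 1), (1, 4), (2, 2), (1, 1)]:
    print(degrees, sp_mode(*degrees))
```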
remove_sequence_parallel
remove_sequence_parallel(
module: Module, plan: SequenceParallelModelPlan
) -> None
Remove sequence parallel hooks from a model.
Note: This corresponds to remove_context_parallel in diffusers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The model to remove SP from. | required |
plan | SequenceParallelModelPlan | The same plan used when applying SP. | required |
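Passing the same plan to remove_sequence_parallel lets removal visit exactly the submodules that were hooked. The round trip below sketches that idea under stated assumptions: ToyModule, get_submodule, apply_sp, remove_sp, and the hook names sp_split and sp_gather are all hypothetical, a per-module dict stands in for the HookRegistry, and "dict means input spec, anything else means output spec" is a simplification of the plan format shown in the apply_sequence_parallel example.

```python
class ToyModule:
    def __init__(self, children=None):
        self.children = children or {}
        self.hooks = {}  # stand-in for a per-module HookRegistry


def get_submodule(root, name):
    """Resolve a dotted name; "" refers to the root module itself."""
    module = root
    for part in filter(None, name.split(".")):
        module = module.children[part]
    return module


# Hypothetical hook names; the real ones used by vLLM-Omni may differ.
SPLIT, GATHER = "sp_split", "sp_gather"


def hook_name(spec):
    # A dict maps input names to split specs; anything else is an output spec.
    return SPLIT if isinstance(spec, dict) else GATHER


def apply_sp(root, plan):
    for name, spec in plan.items():
        get_submodule(root, name).hooks[hook_name(spec)] = spec


def remove_sp(root, plan):
    # Walk the same plan so exactly the hooked modules are visited.
    for name, spec in plan.items():
        get_submodule(root, name).hooks.pop(hook_name(spec), None)


proj_out = ToyModule()
model = ToyModule({"proj_out": proj_out})
plan = {"": {"hidden_states": ("split", 1)}, "proj_out": ("gather", 1)}
apply_sp(model, plan)
print(sorted(model.hooks), sorted(proj_out.hooks))  # ['sp_split'] ['sp_gather']
remove_sp(model, plan)
print(model.hooks, proj_out.hooks)  # {} {}
```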