
vllm_omni.diffusion.hooks

Hook mechanism for model forward interception.

Modules:

  • base: Base hook classes for model forward interception.
  • sequence_parallel: Sequence Parallelism hooks for non-intrusive SP support.

BaseState

Base class for hook state containers.

reset

reset() -> None

HookRegistry

Registry of hooks attached to a module.

Manages multiple hooks that can intercept a module's forward pass. Hooks are called in sorted order by name for determinism.

module instance-attribute

module = module

check_if_exists_or_initialize classmethod

check_if_exists_or_initialize(
    module: Module,
) -> HookRegistry

Get existing registry or create a new one for the module.

This method ensures that a HookRegistry exists on the module and returns it, creating and attaching one if none exists. It is equivalent to get_or_create() and is provided for compatibility with the diffusers API.

Parameters:

  • module (Module, required): The module to get/create a registry for.

Returns:

  • HookRegistry: The HookRegistry for this module.
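The get-or-create behavior can be sketched as follows. This sketch assumes the registry is stored as an attribute on the module. `Module`, `Registry`, and the attribute name `_sketch_hook_registry` are hypothetical stand-ins, not vLLM-Omni names.

```python
from typing import Any, Dict


class Module:
    """Hypothetical stand-in for torch.nn.Module (keeps the sketch dependency-free)."""

    def forward(self, x: Any) -> Any:
        return x


class Registry:
    """Minimal get-or-create sketch; not the vLLM-Omni HookRegistry itself."""

    # Hypothetical attribute name used to attach the registry to the module.
    _ATTR = "_sketch_hook_registry"

    def __init__(self, module: Module) -> None:
        self.module = module
        self.hooks: Dict[str, Any] = {}

    @classmethod
    def get_or_create(cls, module: Module) -> "Registry":
        # Reuse an attached registry if present; otherwise create and attach one.
        registry = getattr(module, cls._ATTR, None)
        if registry is None:
            registry = cls(module)
            setattr(module, cls._ATTR, registry)
        return registry

    # Alias with the diffusers-compatible name; both resolve to the same classmethod.
    check_if_exists_or_initialize = get_or_create


m = Module()
r1 = Registry.get_or_create(m)
r2 = Registry.check_if_exists_or_initialize(m)
```

Both names return the same object for a given module, so repeated calls never stack registries.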

dispatch

dispatch(*args: Any, **kwargs: Any) -> Any

Dispatch a forward call through registered hooks.

Multiple hooks may be registered, with the caveat that only one hook may override new_forward. Pre- and post-processing across hooks are assumed to be composable. For determinism, execution proceeds as follows:

  • Run pre_forward on all hooks in sorted order. Hooks are sorted alphabetically by name, except for the hook that overrides forward (self._new_fwd_impl_hook), which runs last if it exists.

  • If self._new_fwd_impl_hook is not None, call its new_forward. Otherwise, call the original module forward.

  • Run post_forward on all hooks in reverse sorted order.
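The ordering rules above can be sketched in plain Python. This is a minimal illustration, not the HookRegistry implementation. `Hook`, `Tracer`, `Replacer`, and the `overrides_forward` flag are hypothetical. The flag stands in for however the real registry detects the hook that overrides new_forward.

```python
from typing import Any, Callable, Dict, List


class Hook:
    """Hypothetical hook stand-in with the documented pass-through defaults."""

    # Sketch-only flag marking the hook that replaces forward.
    overrides_forward = False

    def pre_forward(self, module: Any, *args: Any, **kwargs: Any):
        return args, kwargs

    def post_forward(self, module: Any, output: Any) -> Any:
        return output

    def new_forward(self, module: Any, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


def dispatch(module: Any, original_forward: Callable[..., Any],
             hooks: Dict[str, Hook], *args: Any, **kwargs: Any) -> Any:
    # Alphabetical order by name, with the forward-overriding hook moved last.
    by_name = [hooks[name] for name in sorted(hooks)]
    fwd_hooks = [h for h in by_name if h.overrides_forward]
    if len(fwd_hooks) > 1:
        raise ValueError("only one hook may override new_forward")
    ordered = [h for h in by_name if not h.overrides_forward] + fwd_hooks

    for hook in ordered:  # 1. pre_forward in sorted order
        args, kwargs = hook.pre_forward(module, *args, **kwargs)

    if fwd_hooks:  # 2. replacement forward, else the original forward
        output = fwd_hooks[0].new_forward(module, *args, **kwargs)
    else:
        output = original_forward(*args, **kwargs)

    for hook in reversed(ordered):  # 3. post_forward in reverse order
        output = hook.post_forward(module, output)
    return output


trace: List[str] = []


class Tracer(Hook):
    def __init__(self, name: str) -> None:
        self.name = name

    def pre_forward(self, module, *args, **kwargs):
        trace.append(f"pre:{self.name}")
        return args, kwargs

    def post_forward(self, module, output):
        trace.append(f"post:{self.name}")
        return output


class Replacer(Tracer):
    overrides_forward = True

    def new_forward(self, module, *args, **kwargs):
        trace.append("new_forward")
        return sum(args)


out = dispatch(None, lambda *a: 0,
               {"b": Tracer("b"), "a": Replacer("a"), "c": Tracer("c")}, 1, 2)
```

Here "a" sorts first alphabetically, but because it overrides forward, its pre_forward runs last and its post_forward runs first.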

Parameters:

  • *args (Any, default ()): Positional arguments to forward.
  • **kwargs (Any, default {}): Keyword arguments to forward.

Returns:

  • Any: The output of the forward pass.

get_hook

get_hook(name: str) -> ModelHook | None

Get a hook by name.

Parameters:

  • name (str, required): The name of the hook.

Returns:

  • ModelHook | None: The hook if found, None otherwise.

get_or_create classmethod

get_or_create(module: Module) -> HookRegistry

Get existing registry or create a new one for the module.

Parameters:

  • module (Module, required): The module to get/create a registry for.

Returns:

  • HookRegistry: The HookRegistry for this module.

register_hook

register_hook(name: str, hook: ModelHook) -> None

Register a hook with the given name.

Parameters:

  • name (str, required): Unique name for this hook.
  • hook (ModelHook, required): The hook instance to register.

remove_hook

remove_hook(name: str) -> None

Remove a hook by name.

Parameters:

  • name (str, required): The name of the hook to remove.

reset

reset() -> None

Reset all hooks and clear the registry.

This removes all hooks from the registry, resets each hook's state, and restores module.forward to its original implementation.
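One way to implement reset is to capture the original bound forward once, shadow it while hooks are registered, and point it back on reset. The sketch below uses that approach. It is an assumption about the mechanism, not the vLLM-Omni code. `Module`, `CountingHook`, and `Registry` are hypothetical.

```python
from typing import Any, Dict


class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def forward(self, x: int) -> int:
        return x + 1


class CountingHook:
    """Hypothetical hook that keeps per-call state."""

    def __init__(self) -> None:
        self.calls = 0

    def post_forward(self, module: Any, output: Any) -> Any:
        self.calls += 1
        return output

    def reset_state(self, module: Any) -> Any:
        self.calls = 0
        return module


class Registry:
    """Minimal sketch: wrap forward on registration, restore it on reset()."""

    def __init__(self, module: Module) -> None:
        self.module = module
        self.hooks: Dict[str, CountingHook] = {}
        self._original_forward = module.forward  # bound method captured once

    def register_hook(self, name: str, hook: CountingHook) -> None:
        self.hooks[name] = hook
        # The instance attribute shadows the class-level forward.
        self.module.forward = self._dispatch

    def _dispatch(self, *args: Any, **kwargs: Any) -> Any:
        output = self._original_forward(*args, **kwargs)
        for name in sorted(self.hooks, reverse=True):
            output = self.hooks[name].post_forward(self.module, output)
        return output

    def reset(self) -> None:
        for hook in self.hooks.values():
            hook.reset_state(self.module)
        self.hooks.clear()
        # Point forward back at the original bound method.
        self.module.forward = self._original_forward


m = Module()
reg = Registry(m)
hook = CountingHook()
reg.register_hook("count", hook)
first = m.forward(1)
calls_before_reset = hook.calls
reg.reset()
after = m.forward(1)
```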

reset_hook

reset_hook(name: str) -> None

Reset a hook's state by name.

Parameters:

  • name (str, required): The name of the hook to reset.

update_sorted_hooks

update_sorted_hooks()

Sort hooks by name, which dictates pre/post process order.

ModelHook

Base class for model hooks that can override a module's forward.

Hooks can intercept the forward pass at two points:

  • pre_forward: called before the original forward; can modify args/kwargs.
  • post_forward: called after the original forward; can modify the output.

Subclasses can override either or both methods. The default implementations pass through args/kwargs/output unchanged.

For more complex behavior, override new_forward to completely replace the forward logic.
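The subclassing pattern can be illustrated with a dependency-free stand-in. `ModelHookSketch` mirrors the documented pass-through defaults, but it is not the real ModelHook. `DoubleInputHook`, `AddOneOutputHook`, and `run_with_hook` are hypothetical. The helper runs a single hook without any registry or ordering logic.

```python
from typing import Any, Callable


class ModelHookSketch:
    """Stand-in mirroring the documented ModelHook defaults (pass-through)."""

    def initialize_hook(self, module: Any) -> Any:
        return module

    def pre_forward(self, module: Any, *args: Any, **kwargs: Any):
        return args, kwargs

    def post_forward(self, module: Any, output: Any) -> Any:
        return output

    def reset_state(self, module: Any) -> Any:
        return module


class DoubleInputHook(ModelHookSketch):
    """Hypothetical hook: doubles the first positional argument before forward."""

    def pre_forward(self, module, *args, **kwargs):
        first, *rest = args
        return (first * 2, *rest), kwargs


class AddOneOutputHook(ModelHookSketch):
    """Hypothetical hook: adds 1 to the output after forward."""

    def post_forward(self, module, output):
        return output + 1


def run_with_hook(hook: ModelHookSketch, forward: Callable[..., Any], *args, **kwargs):
    # pre_forward -> forward -> post_forward for a single hook.
    args, kwargs = hook.pre_forward(None, *args, **kwargs)
    return hook.post_forward(None, forward(*args, **kwargs))


def square(x: int) -> int:
    return x * x


r1 = run_with_hook(DoubleInputHook(), square, 3)   # (3 * 2) ** 2
r2 = run_with_hook(AddOneOutputHook(), square, 3)  # 3 ** 2 + 1
r3 = run_with_hook(ModelHookSketch(), square, 3)   # unchanged: 3 ** 2
```

Each subclass overrides only the method it needs and inherits the pass-through behavior for the other.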

initialize_hook

initialize_hook(module: Module) -> Module

Initialize the hook when it's registered to a module.

Parameters:

Name Type Description Default
module Module

The module this hook is being attached to.

required

Returns:

Type Description
Module

The module (possibly modified).

new_forward

new_forward(
    module: Module, *args: Any, **kwargs: Any
) -> Any

Replace the module's forward pass. Override this method for more complex cases, e.g., TeaCache. If a subclass overrides it, it is called instead of self.module._omni_original_forward when the hooks are executed.
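The following sketch shows a hook that replaces forward entirely and decides for itself whether to call the original computation. It uses a deliberately simple exact-match output cache. This is not TeaCache's algorithm. `CachedForwardHook` and `expensive` are hypothetical, and the signature takes one argument instead of the documented `*args, **kwargs` for brevity.

```python
from typing import Any, Callable


class CachedForwardHook:
    """Hypothetical hook whose new_forward reuses the last output for a repeated input."""

    def __init__(self, original_forward: Callable[[Any], Any]) -> None:
        self.original_forward = original_forward
        self.last_input: Any = None
        self.last_output: Any = None
        self.compute_count = 0

    def new_forward(self, module: Any, x: Any) -> Any:
        if self.compute_count > 0 and x == self.last_input:
            return self.last_output  # cache hit: skip the original forward
        self.last_input = x
        self.last_output = self.original_forward(x)
        self.compute_count += 1
        return self.last_output

    def reset_state(self, module: Any) -> Any:
        # Clear the cache so the next call recomputes.
        self.last_input = self.last_output = None
        self.compute_count = 0
        return module


def expensive(x: int) -> int:
    return x ** 3


hook = CachedForwardHook(expensive)
outputs = [hook.new_forward(None, v) for v in (2, 2, 3)]
```

The second call with input 2 returns the cached result, so the original function runs only twice.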

Parameters:

  • module (Module, required): The module being called.
  • *args (Any, default ()): Positional arguments to forward.
  • **kwargs (Any, default {}): Keyword arguments to forward.

Returns:

  • Any: The output of the replacement for the forward pass.

post_forward

post_forward(module: Module, output: Any) -> Any

Called after the module's forward pass.

Parameters:

  • module (Module, required): The module that was called.
  • output (Any, required): The output from the forward method.

Returns:

  • Any: The (possibly modified) output.

pre_forward

pre_forward(
    module: Module, *args: Any, **kwargs: Any
) -> tuple[tuple, dict]

Called before the module's forward pass.

Parameters:

  • module (Module, required): The module being called.
  • *args (Any, default ()): Positional arguments to forward.
  • **kwargs (Any, default {}): Keyword arguments to forward.

Returns:

  • tuple[tuple, dict]: Tuple of (args, kwargs) to pass to the forward method.

reset_state

reset_state(module: Module) -> Module

Reset any state associated with this hook.

Parameters:

  • module (Module, required): The module this hook is attached to.

Returns:

  • Module: The module.

SequenceParallelGatherHook

Bases: ModelHook

Hook for gathering outputs after a module's forward pass.

This hook is registered to modules that need their outputs gathered from all sequence parallel ranks. It intercepts the output and gathers it according to the plan specification.

Note: This corresponds to ContextParallelGatherHook in diffusers.

config instance-attribute

config = config

metadata instance-attribute

metadata = metadata

initialize_hook

initialize_hook(module: Module) -> Module

post_forward

post_forward(module: Module, output: Any) -> Any

Gather outputs after forward and remove padding if applied.
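Gathering and unpadding can be simulated with plain lists. The sketch assumes each rank holds a contiguous chunk of the sequence, in rank order, and that any padding was appended at the end of the sequence. A real implementation would run an all-gather collective over the SP process group. `gather_and_unpad` is a hypothetical helper.

```python
from typing import List


def gather_and_unpad(shards: List[List[int]], original_len: int) -> List[int]:
    """Concatenate per-rank shards in rank order, then drop trailing padding."""
    gathered = [tok for shard in shards for tok in shard]
    return gathered[:original_len]


# 5 tokens over 2 ranks: the last rank carries 1 padding element (0).
shards = [[10, 11, 12], [13, 14, 0]]
result = gather_and_unpad(shards, original_len=5)
```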

SequenceParallelSplitHook

Bases: ModelHook

Hook for splitting inputs before a module's forward pass.

This hook is registered to modules that need their inputs sharded across sequence parallel ranks. It intercepts the forward call, shards specified inputs according to the plan, and passes the sharded inputs to the original forward.

For entries marked split_output=True, it shards the module's output after the forward pass instead of the input.

Supports both SequenceParallelInput (full split) and SequenceParallelPartialInput (partial split for text/image separation).

Note: This corresponds to ContextParallelSplitHook in diffusers.
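The "full split" case can be illustrated with a list-based sketch. Each rank keeps one contiguous chunk, and the sequence is padded at the end when its length is not divisible by the number of ranks. The padding scheme is an assumption for illustration. `shard_sequence` is a hypothetical helper, and a tensor version would slice along the plan's split_dim instead.

```python
import math
from typing import List, Tuple


def shard_sequence(tokens: List[int], world_size: int, rank: int,
                   pad_value: int = 0) -> Tuple[List[int], int]:
    """Return this rank's contiguous chunk and the amount of padding added."""
    chunk = math.ceil(len(tokens) / world_size)
    pad = chunk * world_size - len(tokens)
    padded = tokens + [pad_value] * pad
    return padded[rank * chunk:(rank + 1) * chunk], pad


tokens = [10, 11, 12, 13, 14]
shards = [shard_sequence(tokens, world_size=2, rank=r)[0] for r in range(2)]
```

The returned padding count is what a later gather step would need in order to strip the filler elements.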

config instance-attribute

config = config

metadata instance-attribute

metadata = metadata

module_forward_metadata instance-attribute

module_forward_metadata: ModuleForwardMetadata | None = None

initialize_hook

initialize_hook(module: Module) -> Module

post_forward

post_forward(module: Module, output: Any) -> Any

Shard outputs for split_output=True entries.

pre_forward

pre_forward(
    module: Module, *args: Any, **kwargs: Any
) -> tuple[tuple, dict]

Shard inputs before forward.

StateManager

Manage per-context hook state instances.
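A per-context state manager can be sketched as a dictionary of lazily created state objects keyed by context name. The context names "cond" and "uncond" are illustrative, as in separate conditional and unconditional passes. `StateManagerSketch`, `BaseStateSketch`, and `StepState` are hypothetical. The sketch assumes that reset() resets each stored state in place.

```python
from typing import Callable, Dict, Optional


class BaseStateSketch:
    """Stand-in for BaseState: a state container with reset()."""

    def reset(self) -> None:
        pass


class StepState(BaseStateSketch):
    """Hypothetical state that counts steps in one context."""

    def __init__(self) -> None:
        self.step = 0

    def reset(self) -> None:
        self.step = 0


class StateManagerSketch:
    """One lazily created state instance per context name."""

    def __init__(self, factory: Callable[[], BaseStateSketch]) -> None:
        self._factory = factory
        self._states: Dict[str, BaseStateSketch] = {}
        self._context: Optional[str] = None

    def set_context(self, name: str) -> None:
        self._context = name

    def get_state(self) -> BaseStateSketch:
        if self._context is None:
            raise RuntimeError("call set_context() before get_state()")
        if self._context not in self._states:
            self._states[self._context] = self._factory()
        return self._states[self._context]

    def reset(self) -> None:
        # Assumption: reset each per-context state in place.
        for state in self._states.values():
            state.reset()


mgr = StateManagerSketch(StepState)
mgr.set_context("cond")
mgr.get_state().step += 2
mgr.set_context("uncond")
mgr.get_state().step += 1
mgr.set_context("cond")
cond_steps = mgr.get_state().step
mgr.set_context("uncond")
uncond_steps = mgr.get_state().step
mgr.reset()
steps_after_reset = mgr.get_state().step
```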

get_state

get_state() -> BaseState

reset

reset() -> None

set_context

set_context(name: str) -> None

apply_sequence_parallel

apply_sequence_parallel(
    module: Module,
    config: SequenceParallelConfig,
    plan: SequenceParallelModelPlan,
) -> None

Apply sequence parallel hooks to a model according to the plan.

This function registers hooks on the specified submodules to automatically shard inputs and gather outputs for sequence parallelism.

Note: This corresponds to apply_context_parallel in diffusers.

The complete SP flow is:

  1. Input sharding (SequenceParallelSplitHook): split the sequence across SP ranks.
  2. Attention parallelism (handled by vLLM-Omni's Attention layer):
     • Ulysses: All-to-All over Q/K/V heads
     • Ring: K/V circulation in ring topology
     • Hybrid: both (Ulysses handles head redistribution, Ring handles K/V)
  3. Output gathering (SequenceParallelGatherHook): gather the sequence from SP ranks.
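The Ulysses All-to-All in step 2 can be simulated with nested lists to show the layout change. Before the exchange, each rank holds S/P tokens with all H heads. Afterwards, each rank holds all S tokens with H/P heads. The sketch assumes H is divisible by P and that shards are contiguous in rank order. `ulysses_all_to_all` is a hypothetical helper, not vLLM-Omni's Attention code, which would use a real collective.

```python
from typing import List, Tuple

Item = Tuple[int, int]  # (token_index, head_index), easy to inspect


def ulysses_all_to_all(local: List[List[List[Item]]], world_size: int) -> List[List[List[Item]]]:
    """local[r][s][h] (sequence-sharded) -> out[r][s][h] (head-sharded)."""
    num_heads = len(local[0][0])
    heads_per_rank = num_heads // world_size
    out = []
    for dst in range(world_size):
        head_slice = slice(dst * heads_per_rank, (dst + 1) * heads_per_rank)
        # dst receives its head slice from every source rank, in sequence order.
        out.append([token[head_slice] for src in range(world_size) for token in local[src]])
    return out


S, H, P = 4, 4, 2
local = [[[(s, h) for h in range(H)] for s in range(r * S // P, (r + 1) * S // P)]
         for r in range(P)]
out = ulysses_all_to_all(local, P)
```

After the exchange, rank 1 holds every token but only heads 2 and 3, so attention over the full sequence can run locally for those heads.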

Parameters:

  • module (Module, required): The model to apply SP to.
  • config (SequenceParallelConfig, required): The sequence parallel configuration.
  • plan (SequenceParallelModelPlan, required): Dictionary mapping module names to input/output specifications.

Example

    config = SequenceParallelConfig(ulysses_degree=2)
    plan = {
        "": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
    }
    apply_sequence_parallel(model, config, plan)
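The plan in the example above can be read as a list of hook placements. This simplified sketch assumes two things: a dict value lists inputs to shard, which means a split hook on that submodule, and an output spec means a gather hook. The sketch also assumes "" denotes the root module. In PyTorch, such dotted names would typically be resolved with torch.nn.Module.get_submodule. `SPInput`, `SPOutput`, and `plan_to_hooks` are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class SPInput:  # hypothetical stand-in for SequenceParallelInput
    split_dim: int
    expected_dims: int


@dataclass
class SPOutput:  # hypothetical stand-in for SequenceParallelOutput
    gather_dim: int
    expected_dims: int


def plan_to_hooks(plan: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Map each plan entry to (submodule path, hook kind)."""
    hooks = []
    for name, spec in plan.items():
        target = name or "<root>"
        if isinstance(spec, dict):
            hooks.append((target, "split"))   # inputs to shard before forward
        elif isinstance(spec, SPOutput):
            hooks.append((target, "gather"))  # outputs to gather after forward
        else:
            raise TypeError(f"unsupported plan entry for {name!r}: {spec!r}")
    return hooks


plan = {
    "": {"hidden_states": SPInput(split_dim=1, expected_dims=3)},
    "proj_out": SPOutput(gather_dim=1, expected_dims=3),
}
hooks = plan_to_hooks(plan)
```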

Note

vLLM-Omni's Attention layer automatically handles the internal parallelism (Ulysses All-to-All or Ring attention) based on the forward_context configuration. This function only handles input/output sharding for the model as a whole.

disable_sequence_parallel_for_model

disable_sequence_parallel_for_model(model: Module) -> None

Disable sequence parallelism for a model.

Note: This corresponds to disable_context_parallel_for_model in diffusers.

Parameters:

  • model (Module, required): The model to disable SP for.

enable_sequence_parallel_for_model

enable_sequence_parallel_for_model(
    model: Module,
    config: SequenceParallelConfig | None = None,
) -> None

Enable sequence parallelism for a model using its _sp_plan.

This is a convenience function that reads the model's _sp_plan attribute and applies sequence parallelism automatically.

Note: This corresponds to enable_context_parallel_for_model in diffusers, but uses vLLM-Omni's _sp_plan instead of diffusers' _cp_plan.

The function performs two main tasks:

  1. Applies _sp_plan hooks to shard inputs and gather outputs.
  2. Ensures Attention layers are configured for the correct parallel mode (handled automatically by vLLM-Omni's forward_context mechanism).
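The plan-lookup part of this convenience function can be sketched as follows. The sketch reads the class-level _sp_plan, fails with ValueError when it is missing, and falls back to a default config. `enable_sp_sketch`, `apply`, and `default_config` are hypothetical parameters. In particular, `default_config` stands in for "a default based on the current parallel state", which this sketch does not reproduce.

```python
from typing import Any, Callable, Dict, List, Optional


def enable_sp_sketch(model: Any,
                     apply: Callable[[Any, Any, Dict[str, Any]], None],
                     default_config: Callable[[], Any],
                     config: Optional[Any] = None) -> None:
    """Read model._sp_plan and hand it to `apply` (hypothetical helper)."""
    plan = getattr(model, "_sp_plan", None)
    if plan is None:
        raise ValueError(f"{type(model).__name__} has no _sp_plan defined")
    if config is None:
        config = default_config()  # placeholder for the parallel-state default
    apply(model, config, plan)


class PlannedModel:
    _sp_plan = {"": {"hidden_states": "split-spec"}, "proj_out": "gather-spec"}


class UnplannedModel:
    pass


calls: List[tuple] = []
enable_sp_sketch(PlannedModel(), lambda m, c, p: calls.append((c, p)),
                 default_config=lambda: "default")
try:
    enable_sp_sketch(UnplannedModel(), lambda m, c, p: None,
                     default_config=lambda: "default")
    raised = False
except ValueError:
    raised = True
```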

Parameters:

  • model (Module, required): The model to enable SP for. Must have a _sp_plan attribute.
  • config (SequenceParallelConfig | None, default None): Optional config. If None, a default based on the current parallel state is used.

Raises:

  • ValueError: If the model has no _sp_plan defined.

Note

vLLM-Omni supports Ulysses + Ring hybrid parallelism:

  • ulysses_degree > 1: uses All-to-All communication over Q/K/V heads.
  • ring_degree > 1: uses Ring attention with K/V passing.
  • Both > 1: hybrid mode (Ulysses handles head redistribution, Ring handles K/V circulation).
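The mode selection described above reduces to two comparisons. The sketch below classifies a pair of degrees. `sp_attention_mode` is a hypothetical helper, and the mode strings are illustrative labels rather than vLLM-Omni identifiers.

```python
def sp_attention_mode(ulysses_degree: int, ring_degree: int) -> str:
    """Classify the attention parallel mode from the two degrees (sketch)."""
    if ulysses_degree < 1 or ring_degree < 1:
        raise ValueError("degrees must be >= 1")
    if ulysses_degree > 1 and ring_degree > 1:
        return "hybrid"   # Ulysses redistributes heads, Ring circulates K/V
    if ulysses_degree > 1:
        return "ulysses"  # All-to-All over Q/K/V heads
    if ring_degree > 1:
        return "ring"     # K/V passed around the ring
    return "none"         # both degrees are 1: no sequence parallelism
```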

remove_sequence_parallel

remove_sequence_parallel(
    module: Module, plan: SequenceParallelModelPlan
) -> None

Remove sequence parallel hooks from a model.

Note: This corresponds to remove_context_parallel in diffusers.

Parameters:

Name Type Description Default
module Module

The model to remove SP from.

required
plan SequenceParallelModelPlan

The same plan used when applying SP.

required
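The same plan is required because it identifies which submodules received hooks, and what kind. The sketch below models each submodule's registry as a plain dict. It removes only the SP hooks and tolerates entries that were never applied. The hook names "sp_split" and "sp_gather" and the helper `remove_sp_sketch` are hypothetical.

```python
from typing import Any, Dict

SPLIT_HOOK = "sp_split"    # hypothetical hook names
GATHER_HOOK = "sp_gather"


def remove_sp_sketch(registries: Dict[str, Dict[str, Any]], plan: Dict[str, Any]) -> None:
    """registries maps submodule path -> {hook name: hook}."""
    for name, spec in plan.items():
        hooks = registries.get(name, {})
        # Dict specs received a split hook; output specs received a gather hook.
        hook_name = SPLIT_HOOK if isinstance(spec, dict) else GATHER_HOOK
        hooks.pop(hook_name, None)  # no error if the hook is already gone


registries = {"": {"sp_split": 1, "other": 2}, "proj_out": {"sp_gather": 3}}
plan = {"": {"hidden_states": "split-spec"}, "proj_out": "gather-spec"}
remove_sp_sketch(registries, plan)
```

Unrelated hooks registered on the same modules (such as "other" here) are left in place.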