CFG-Parallel¶
This section describes how to add CFG-Parallel (Classifier-Free Guidance Parallel) to a diffusion pipeline. We use the Qwen-Image pipeline as the reference implementation.
Table of Contents¶
- Overview
- Step-by-Step Implementation
- Customization
- Testing
- Troubleshooting
- Reference Implementations
- Summary
Overview¶
What is CFG-Parallel?¶
In standard Classifier-Free Guidance, each diffusion step requires two forward passes through the transformer:
- Positive/Conditional: Guided by the text prompt
- Negative/Unconditional: Typically uses an empty or negative prompt
Some models require 3 or more CFG branches (see N-Branch CFG).
Running these passes one after another roughly doubles the transformer cost of every step. CFG-Parallel removes this bottleneck by distributing the forward passes across different GPU ranks, so they execute simultaneously rather than sequentially.
Architecture¶
vLLM-omni provides CFGParallelMixin, which encapsulates all CFG-Parallel logic. Pipelines inherit from this mixin and implement a diffuse() method that orchestrates the denoising loop. The mixin exposes the following methods:
| Method | Purpose | Automatic Behavior |
|---|---|---|
| predict_noise_maybe_with_cfg() | Predict noise with 2-branch CFG | Detects parallel mode, distributes computation, gathers results |
| predict_noise_with_multi_branch_cfg() | Predict noise with N-branch CFG | Round-robin dispatches N branches across M GPUs |
| scheduler_step_maybe_with_cfg() | Step scheduler | All ranks step locally (no broadcast needed) |
| combine_cfg_noise() | Combine 2-branch predictions | Applies CFG formula with optional normalization |
| combine_multi_branch_cfg_noise() | Combine N-branch predictions | Override for custom multi-branch combine logic |
| predict_noise() | Forward pass wrapper | Override for custom transformer calls |
| cfg_normalize_function() | Normalize CFG output | Override for custom normalization |
How It Works¶
predict_noise_maybe_with_cfg() automatically detects and switches between two execution modes:
- CFG-Parallel mode (when cfg_world_size > 1):
  - Rank 0 computes the positive prompt prediction
  - Rank 1 computes the negative prompt prediction
  - Results are gathered via all_gather()
  - All ranks compute the CFG combine locally (deterministic, identical results)
- Sequential mode (when cfg_world_size == 1):
  - A single rank computes both positive and negative predictions
  - It directly combines them with the CFG formula

scheduler_step_maybe_with_cfg() ensures consistent latent states across all ranks:
- All ranks compute the scheduler step locally. No broadcast is needed because predict_noise_maybe_with_cfg() already ensures all ranks have identical noise predictions after all_gather() and the local combine.
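The two modes can be illustrated with a small, framework-free sketch. Everything in it is an assumption made for illustration: fake_transformer and cfg_combine are hypothetical, a plain list stands in for all_gather(), floats stand in for tensors, and the combine uses the commonly cited CFG formula neg + scale * (pos - neg), which may differ from what combine_cfg_noise() actually applies.

```python
# Toy model of the two execution modes; plain floats stand in for tensors.

def fake_transformer(latent: float, prompt_weight: float) -> float:
    """Hypothetical stand-in for one transformer forward pass."""
    return latent * prompt_weight


def cfg_combine(pos: float, neg: float, scale: float) -> float:
    """Commonly used CFG formula (an assumption, not taken from the mixin)."""
    return neg + scale * (pos - neg)


def sequential_step(latent: float, scale: float) -> float:
    # cfg_world_size == 1: one rank runs both passes back to back.
    pos = fake_transformer(latent, prompt_weight=2.0)
    neg = fake_transformer(latent, prompt_weight=0.5)
    return cfg_combine(pos, neg, scale)


def parallel_step(latent: float, scale: float) -> list:
    # cfg_world_size == 2: rank 0 runs the positive pass, rank 1 the negative pass.
    rank_weights = [2.0, 0.5]
    local_preds = [fake_transformer(latent, w) for w in rank_weights]
    gathered = list(local_preds)  # stands in for all_gather(): every rank sees both
    # Each rank combines locally, so every rank ends up with the same value.
    return [cfg_combine(gathered[0], gathered[1], scale) for _rank in range(2)]


sequential = sequential_step(latent=1.0, scale=4.0)
per_rank = parallel_step(latent=1.0, scale=4.0)
print(sequential, per_rank)
```

Because every rank holds both gathered predictions and applies the same deterministic formula, the per-rank results match the sequential result, which is why the scheduler step needs no broadcast.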
N-Branch CFG (3+ branches)¶
Some models require more than 2 CFG branches. For example, Bagel and OmniGen2 use 3 branches, and DreamID Omni uses 4 branches.
predict_noise_with_multi_branch_cfg() handles these by automatically dispatching N branches across M GPUs using round-robin (rule: branch i → rank i % M):
| Branches (N) | GPUs (M) | Dispatch (branch indices per rank) |
|---|---|---|
| 3 | 2 | [[0, 2], [1]] |
| 3 | 3 | [[0], [1], [2]] |
| 4 | 2 | [[0, 2], [1, 3]] |
| 4 | 3 | [[0, 3], [1], [2]] |
| 4 | 4 | [[0], [1], [2], [3]] |
When a rank handles multiple branches, it runs them sequentially. After all_gather, all ranks execute combine_multi_branch_cfg_noise() locally, producing identical results.
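The round-robin rule can be reproduced in a few lines. This is a minimal sketch of the rule as stated above (branch i goes to rank i % M); the function name round_robin_dispatch is hypothetical and is not part of the mixin's API.

```python
def round_robin_dispatch(num_branches: int, num_ranks: int) -> list:
    """Return, for each rank, the list of branch indices it runs."""
    if num_branches < 1 or num_ranks < 1:
        raise ValueError("num_branches and num_ranks must be positive")
    assignment = [[] for _ in range(num_ranks)]
    for branch in range(num_branches):
        # Rule from the table: branch i -> rank i % M.
        assignment[branch % num_ranks].append(branch)
    return assignment


for n, m in [(3, 2), (4, 2), (4, 3)]:
    print(f"N={n}, M={m}: {round_robin_dispatch(n, m)}")
```

The per-step latency is set by the rank with the most branches, so 4 branches on 3 GPUs takes as long as 4 branches on 2 GPUs: in both cases one rank runs two branches sequentially.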
Step-by-Step Implementation¶
Step 1: Inherit CFGParallelMixin¶
Make your pipeline inherit from CFGParallelMixin and implement the diffuse() method for your specific model.
Example (Qwen-Image):
import torch
import torch.nn as nn

from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin


class YourModelPipeline(nn.Module, CFGParallelMixin):
    def diffuse(self, ...) -> torch.Tensor:
        for i, t in enumerate(timesteps):
            # Prepare positive_kwargs (conditional) and negative_kwargs (unconditional)
            positive_kwargs = {...}  # hidden_states, encoder_hidden_states, etc.
            negative_kwargs = {...} if do_true_cfg else None

            # Key method 1: Predict noise with automatic CFG parallel handling
            noise_pred = self.predict_noise_maybe_with_cfg(
                do_true_cfg=do_true_cfg,
                true_cfg_scale=true_cfg_scale,
                positive_kwargs=positive_kwargs,
                negative_kwargs=negative_kwargs,
            )

            # Key method 2: Step scheduler with automatic CFG synchronization
            latents = self.scheduler_step_maybe_with_cfg(
                noise_pred, t, latents, do_true_cfg
            )
        return latents
Key Points:
- positive_kwargs: transformer arguments for conditional (text-guided) prediction
- negative_kwargs: transformer arguments for unconditional prediction (set to None if CFG disabled)
- For image editing pipelines, add output_slice=image_seq_len to extract the generative image portion
- For models with 3+ CFG branches, see Multi-Branch CFG in the Customization section
Step 2: Call diffuse¶
Call self.diffuse in your pipeline's forward function:
import torch.nn as nn

from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin


class YourModelPipeline(nn.Module, CFGParallelMixin):
    def forward(
        self,
        prompt: str,
        negative_prompt: str | None = None,
        guidance_scale: float = 3.5,
        num_inference_steps: int = 50,
        **kwargs,
    ):
        # Encode prompts, initialize latents, get timesteps
        ...

        # Run the diffusion loop (the diffuse() method implemented in Step 1)
        latents = self.diffuse(
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds=negative_embeds,
            negative_prompt_embeds_mask=negative_mask,
            latents=latents,
            timesteps=timesteps,
            do_true_cfg=do_true_cfg,
            true_cfg_scale=guidance_scale,
            ...
        )
Customization¶
Override predict_noise() for Custom Transformer Calls¶
If your transformer requires a custom prediction function, override predict_noise(). Wan2.2, for example, has two transformer models. Its override calls the model passed as current_model and falls back to self.transformer when none is given.
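from typing import Any

class Wan22Pipeline(nn.Module, CFGParallelMixin):
    def predict_noise(self, current_model: nn.Module | None = None, **kwargs: Any) -> torch.Tensor:
        # Use the transformer chosen by the caller; default to self.transformer.
        if current_model is None:
            current_model = self.transformer
        # The noise prediction is the first element of the transformer output.
        return current_model(**kwargs)[0]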
Override cfg_normalize_function() for Custom Normalization¶
Some models have their own normalization function. The LongCat Image model is one example:
class LongCatImagePipeline(nn.Module, CFGParallelMixin):
    def cfg_normalize_function(self, noise_pred, comb_pred, cfg_renorm_min=0.0):
        """
        Normalize the combined noise prediction.
        """
        cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
        noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
        scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
        noise_pred = comb_pred * scale
        return noise_pred

# The original cfg_normalize_function function in CFGParallelMixin:
#     cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
#     noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
#     noise_pred = comb_pred * (cond_norm / noise_norm)
#     return noise_pred
Multi-Branch CFG (3+ branches)¶
For models with 3 or more CFG branches, use predict_noise_with_multi_branch_cfg() instead of predict_noise_maybe_with_cfg(), and override combine_multi_branch_cfg_noise() for custom combine logic. This interface also works for standard 2-branch CFG — just pass 2 branches in branches_kwargs.
Example (3-branch with dual guidance scale):
class YourMultiBranchPipeline(nn.Module, CFGParallelMixin):
    def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
        text_scale = true_cfg_scale["text"]
        image_scale = true_cfg_scale["image"]
        pos, ref, uncond = predictions
        return uncond + image_scale * (ref - uncond) + text_scale * (pos - ref)

    def diffuse(self, ...):
        for i, t in enumerate(timesteps):
            positive_kwargs = {...}  # conditional prompt
            ref_neg_kwargs = {...}   # negative prompt + reference
            uncond_kwargs = {...}    # unconditional

            noise_pred = self.predict_noise_with_multi_branch_cfg(
                do_true_cfg=do_true_cfg,
                true_cfg_scale={"text": text_guidance_scale, "image": image_guidance_scale},
                branches_kwargs=[positive_kwargs, ref_neg_kwargs, uncond_kwargs],
            )
            latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
        return latents
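A quick arithmetic check of the dual-scale formula in this example, using plain floats in place of tensors. The function combine_three is a hypothetical standalone copy of the formula, and the input values are arbitrary.

```python
def combine_three(pos, ref, uncond, text_scale, image_scale):
    """Standalone copy of the 3-branch combine formula from the example."""
    return uncond + image_scale * (ref - uncond) + text_scale * (pos - ref)


pos, ref, uncond = 2.0, 1.0, 0.5

# With both scales at 1.0 the terms telescope and the result is just `pos`.
no_guidance = combine_three(pos, ref, uncond, text_scale=1.0, image_scale=1.0)

# With both scales at 0.0 only the unconditional prediction remains.
unconditional_only = combine_three(pos, ref, uncond, text_scale=0.0, image_scale=0.0)

# Larger scales push the result further along each guidance direction.
guided = combine_three(pos, ref, uncond, text_scale=4.0, image_scale=2.0)

print(no_guidance, unconditional_only, guided)
```

The image scale weights the step from the unconditional to the reference branch, and the text scale weights the step from the reference to the fully conditioned branch, so the two kinds of guidance can be tuned independently.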
Override Combine Functions¶
There are two combine functions for different scenarios:
- combine_cfg_noise(): Used by predict_noise_maybe_with_cfg(). Override when predict_noise() returns a tuple (e.g., video + audio) and you need per-element CFG logic.
- combine_multi_branch_cfg_noise(): Used by predict_noise_with_multi_branch_cfg(). Override to implement custom multi-branch combine formulas (see Multi-Branch CFG above).
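For the tuple case, a per-element combine might look like the sketch below. It is written as a free function so it runs on its own. The parameter names mirror the combine_cfg_noise() signature used elsewhere in this guide, while the formula neg + scale * (pos - neg) and the decision to ignore normalize are assumptions. Plain floats stand in for the video and audio tensors.

```python
def combine_cfg_noise_per_element(positive_noise_pred, negative_noise_pred, scale, normalize=False):
    """Apply an assumed CFG formula to each element of a tuple prediction."""
    if len(positive_noise_pred) != len(negative_noise_pred):
        raise ValueError("positive and negative predictions must have the same length")
    # `normalize` is accepted only to mirror the signature; this sketch ignores it.
    return tuple(
        neg + scale * (pos - neg)
        for pos, neg in zip(positive_noise_pred, negative_noise_pred)
    )


# (video, audio) predictions for the positive and negative branches.
video_audio = combine_cfg_noise_per_element((2.0, 4.0), (1.0, 3.0), scale=3.0)
print(video_audio)
```

Since the expression only uses addition, subtraction, and scalar multiplication, the same body would work unchanged on tensor elements.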
Implement a Composite Scheduler for Multi-Output Models¶
When each output has its own denoising schedule, implement a composite scheduler that dispatches to per-output schedulers. Assign it to self.scheduler so the default scheduler_step() works without override.
Complete example (video + audio with separate schedulers and diffuse loop):
class VideoAudioScheduler:
    """Composite scheduler dispatching to video and audio schedulers."""

    def __init__(self, video_scheduler, audio_scheduler):
        self.video_scheduler = video_scheduler
        self.audio_scheduler = audio_scheduler

    def step(self, noise_pred, t, latents, return_dict=False, generator=None):
        # Each argument is a (video, audio) pair; step each output with its own scheduler.
        video_out = self.video_scheduler.step(noise_pred[0], t[0], latents[0], return_dict=False, generator=generator)[0]
        audio_out = self.audio_scheduler.step(noise_pred[1], t[1], latents[1], return_dict=False, generator=generator)[0]
        return ((video_out, audio_out),)


class MyVideoAudioPipeline(nn.Module, CFGParallelMixin):
    def __init__(self, ...):
        super().__init__()
        self.scheduler = VideoAudioScheduler(video_sched, audio_sched)

    def predict_noise(self, **kwargs):
        video_pred, audio_pred = self.transformer(**kwargs)
        return (video_pred, audio_pred)

    def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize):
        # Per-element CFG combine for the (video, audio) tuple (see Override Combine Functions)
        ...

    def diffuse(self, video_latents, audio_latents, timesteps_video, timesteps_audio, ...):
        for t_v, t_a in zip(timesteps_video, timesteps_audio):
            positive_kwargs = {...}
            negative_kwargs = {...} if do_true_cfg else None

            video_pred, audio_pred = self.predict_noise_maybe_with_cfg(
                do_true_cfg=do_true_cfg, true_cfg_scale=self.guidance_scale,
                positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs,
            )
            video_latents, audio_latents = self.scheduler_step_maybe_with_cfg(
                (video_pred, audio_pred), (t_v, t_a),
                (video_latents, audio_latents), do_true_cfg=do_true_cfg,
                generator=generator,
            )
        return video_latents, audio_latents
Note: If you use a non-deterministic scheduler (e.g., DDPM), pass a seeded generator explicitly, as in self.scheduler_step_maybe_with_cfg(..., generator=torch.Generator(device).manual_seed(seed)), so that the randomness of the scheduler step is identical across ranks.
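Why a shared seed matters can be seen in a toy example. This sketch uses Python's random.Random as a stand-in for torch.Generator, and noisy_step is a hypothetical stochastic update rather than a real scheduler. Both are assumptions chosen so that the example runs without any GPU or framework.

```python
import random


def noisy_step(latent: float, noise_pred: float, rng: random.Random) -> float:
    """Hypothetical stochastic scheduler step: deterministic update plus noise."""
    return latent - 0.1 * noise_pred + 0.01 * rng.gauss(0.0, 1.0)


def run_rank(seed: int, steps: int = 5) -> float:
    """Simulate one rank running the denoising loop with its own generator."""
    rng = random.Random(seed)
    latent = 1.0
    for _ in range(steps):
        latent = noisy_step(latent, noise_pred=0.5, rng=rng)
    return latent


# Every rank seeds its generator identically: the latents stay in sync.
same_seed = [run_rank(seed=1234) for _rank in range(2)]

# Each rank uses a different seed: the latents drift apart.
different_seeds = [run_rank(seed=1234 + rank) for rank in range(2)]

print(same_seed[0] == same_seed[1], different_seeds[0] == different_seeds[1])
```

Since every rank steps the scheduler locally and nothing is broadcast afterwards, any divergence in the random draws would leave the ranks with different latents on the next step.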
Testing¶
After adding CFG-Parallel support, test with:
cd examples/offline_inference/text_to_image
python text_to_image.py \
    --model Your-org/your-model \
    --prompt "a cup of coffee on the table" \
    --negative-prompt "ugly, unclear" \
    --cfg-scale 4.0 \
    --num-inference-steps 50 \
    --output "cfg_enabled.png" \
    --cfg-parallel-size 2
Verify:
- Check logs for CFG parallel being activated
- Record the e2e_time_ms in the log and compare with CFG-Parallel disabled
- Compare the generated result quality with baseline
- Record comparison results in your PR
Troubleshooting¶
Issue: CFG parallel not activating¶
Symptoms: Generation still slow, logs don't show CFG parallel being used.
Causes & Solutions:
- CFG is not enabled:
  - Problem: Guidance scale too low or negative prompt not provided.
  - Solution: Ensure guidance_scale > 1.0 and a negative prompt is provided:

images = pipeline(
    prompt="a cat",
    negative_prompt="",  # Must provide (even if empty)
    guidance_scale=3.5,  # Must be > 1.0
)
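The two requirements above can be written as a small check that is handy when debugging. The helper name is_cfg_enabled is hypothetical, and the exact condition a given pipeline uses to set do_true_cfg may differ. This sketch only encodes the two requirements stated in this section.

```python
from typing import Optional


def is_cfg_enabled(guidance_scale: float, negative_prompt: Optional[str]) -> bool:
    """Return True when both requirements from this section hold."""
    # An empty string counts as "provided"; only None means no negative prompt.
    return guidance_scale > 1.0 and negative_prompt is not None


print(is_cfg_enabled(3.5, ""))    # True: scale > 1.0 and a prompt was passed
print(is_cfg_enabled(3.5, None))  # False: negative prompt missing
print(is_cfg_enabled(1.0, ""))    # False: scale is not greater than 1.0
```

If the check returns False for your arguments, CFG itself is off, so there is nothing to distribute regardless of the parallel size setting.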
Reference Implementations¶
Complete examples in the codebase:
| Model | Path | Pattern | Notes |
|---|---|---|---|
| Qwen-Image | vllm_omni/diffusion/models/qwen_image/cfg_parallel.py | Mixin | Dual-stream transformer |
| Qwen-Image-Edit | vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py | Mixin | Image editing with output_slice |
| Wan2.2 | vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py | Mixin | Dual-transformer architecture |
| CFGParallelMixin | vllm_omni/diffusion/distributed/cfg_parallel.py | Base implementation | Core mixin class |
Summary¶
Adding CFG-Parallel support:
- ✅ Create mixin - Inherit from CFGParallelMixin and implement diffuse() method
- ✅ (Optional) Customize - Override predict_noise() or cfg_normalize_function() for custom behavior
- ✅ (Optional) Multi-branch - For 3+ branch models, use predict_noise_with_multi_branch_cfg() and override combine_multi_branch_cfg_noise()
- ✅ Test - Verify with --cfg-parallel-size 2 (or 3/4 for multi-branch) and compare performance