vllm_omni.diffusion.models.flux

FLUX diffusion model components.

Modules:

Name                    Description
flux_pipeline_mixin     Flux Pipeline Mixin - Shared methods for Flux pipelines.
flux_transformer
pipeline_flux
pipeline_flux_kontext

FluxDMD2Pipeline

Bases: DMD2PipelineMixin, FluxPipeline

Flux pipeline for FastGen DMD2-distilled models.

FluxKontextPipeline

Bases: Module, FluxPipelineMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin

FLUX.1-Kontext pipeline for image editing with text guidance.

default_sample_size instance-attribute

default_sample_size = 128

guidance_scale property

guidance_scale

image_processor instance-attribute

image_processor = VaeImageProcessor(
    vae_scale_factor=vae_scale_factor * 2
)

interrupt property

interrupt

joint_attention_kwargs property

joint_attention_kwargs

latent_channels instance-attribute

latent_channels = (
    latent_channels if hasattr(vae, "config") else 16
)

num_timesteps property

num_timesteps

od_config instance-attribute

od_config = od_config

scheduler instance-attribute

scheduler = from_pretrained(
    model,
    subfolder="scheduler",
    local_files_only=local_files_only,
)

support_image_input class-attribute instance-attribute

support_image_input = True

text_encoder instance-attribute

text_encoder = to(_execution_device)

text_encoder_2 instance-attribute

text_encoder_2 = to(_execution_device)

tokenizer instance-attribute

tokenizer = from_pretrained(
    model,
    subfolder="tokenizer",
    local_files_only=local_files_only,
)

tokenizer_2 instance-attribute

tokenizer_2 = from_pretrained(
    model,
    subfolder="tokenizer_2",
    local_files_only=local_files_only,
)

tokenizer_max_length instance-attribute

tokenizer_max_length = (
    model_max_length
    if hasattr(self, "tokenizer") and tokenizer is not None
    else 77
)

transformer instance-attribute

transformer = FluxKontextTransformer2DModel(
    **transformer_kwargs
)

vae instance-attribute

vae = to(_execution_device)

vae_scale_factor instance-attribute

vae_scale_factor = (
    2 ** (len(block_out_channels) - 1)
    if getattr(self, "vae", None)
    else 8
)
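The following standalone sketch shows how `vae_scale_factor`, `default_sample_size`, and the `vae_scale_factor * 2` passed to `VaeImageProcessor` relate to each other. It assumes a VAE with four entries in `block_out_channels` (as in the FLUX.1 VAE). It also assumes that, as in the diffusers Flux pipelines, latents are packed into 2x2 patches, so pixel sizes must be multiples of `vae_scale_factor * 2`. The helper names are hypothetical and are not part of vllm_omni.

```python
from __future__ import annotations


def compute_vae_scale_factor(block_out_channels: list[int] | None) -> int:
    # Hypothetical helper: each VAE down block except the last halves H and W.
    if not block_out_channels:
        return 8  # fallback used when no VAE is attached
    return 2 ** (len(block_out_channels) - 1)


def snap_to_multiple(size: int, multiple: int) -> int:
    # Round down so the latent grid can be split into 2x2 patches.
    return (size // multiple) * multiple


block_out_channels = [128, 256, 512, 512]  # assumed FLUX.1 VAE config
scale = compute_vae_scale_factor(block_out_channels)
multiple = scale * 2              # the factor handed to VaeImageProcessor
default_side = 128 * scale        # default_sample_size * vae_scale_factor

print(scale, multiple, default_side)     # 8 16 1024
print(snap_to_multiple(1000, multiple))  # 992
```

Under these assumptions, omitting `height` and `width` yields a 1024x1024 image, and any requested size is rounded down to a multiple of 16 pixels.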

weights_sources instance-attribute

weights_sources = [
    ComponentSource(
        model_or_path=model,
        subfolder="transformer",
        revision=None,
        prefix="transformer.",
        fall_back_to_pt=True,
    )
]
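The `prefix` field suggests that checkpoint tensors read from a subfolder are renamed to match the pipeline's attribute names, for example `transformer.` for weights loaded from `transformer/`. Below is a minimal sketch of that renaming. It assumes the loader simply prepends `prefix` to every tensor name. `Source` and `iter_prefixed` are hypothetical stand-ins, not the real `ComponentSource` API.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, Tuple


@dataclass
class Source:
    # Hypothetical stand-in mirroring two ComponentSource fields.
    subfolder: str
    prefix: str


def iter_prefixed(
    source: Source, weights: Iterable[Tuple[str, object]]
) -> Iterator[Tuple[str, object]]:
    # Prepend the source prefix so names match the pipeline's module tree.
    for name, tensor in weights:
        yield source.prefix + name, tensor


src = Source(subfolder="transformer", prefix="transformer.")
names = [n for n, _ in iter_prefixed(src, [("x_embedder.weight", None)])]
print(names)  # ['transformer.x_embedder.weight']
```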

check_inputs

check_inputs(
    prompt,
    prompt_2=None,
    height=None,
    width=None,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
)

encode_prompt

encode_prompt(
    prompt: str | list[str],
    prompt_2: str | list[str] | None = None,
    device: device | None = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: FloatTensor | None = None,
    pooled_prompt_embeds: FloatTensor | None = None,
    max_sequence_length: int = 512,
)

forward

forward(
    req: OmniDiffusionRequest,
    image: Image | list[Image] | None = None,
    prompt: str | list[str] | None = None,
    prompt_2: str | list[str] | None = None,
    negative_prompt: str | list[str] | None = None,
    negative_prompt_2: str | list[str] | None = None,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    true_cfg_scale: float = 1.0,
    num_images_per_prompt: int = 1,
    generator: Generator | list[Generator] | None = None,
    latents: Tensor | None = None,
    prompt_embeds: Tensor | None = None,
    pooled_prompt_embeds: Tensor | None = None,
    negative_prompt_embeds: Tensor | None = None,
    negative_pooled_prompt_embeds: Tensor | None = None,
    output_type: str | None = "pil",
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
    callback_on_step_end: Callable[[int, int, dict], None]
    | None = None,
    callback_on_step_end_tensor_inputs: list[str] = [
        "latents"
    ],
    max_sequence_length: int = 512,
    sigmas: list[float] | None = None,
) -> DiffusionOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

prepare_latents

prepare_latents(
    image: Tensor | None,
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: dtype,
    device: device,
    generator: Generator | list[Generator] | None = None,
    latents: Tensor | None = None,
)
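This class appears to mirror the diffusers FLUX.1-Kontext pipeline. In that pipeline, the noise latents and the encoded reference image are each packed into 2x2 patches. Their position IDs differ only in the first axis: 0 for the generated image and 1 for the reference image. The two token sequences are concatenated along the sequence axis before the transformer call, and only the first part of the prediction is kept.

Here is a shape-only sketch under those assumptions. It uses pure Python, the function names are hypothetical, and it assumes 16 latent channels and `vae_scale_factor=8`.

```python
def packed_shape(batch, channels, height, width, vae_scale_factor=8):
    # Pixel size -> latent grid (kept even), then 2x2 patches fold 4 positions
    # into the channel axis: (batch, tokens, channels * 4).
    lat_h = 2 * (height // (vae_scale_factor * 2))
    lat_w = 2 * (width // (vae_scale_factor * 2))
    return batch, (lat_h // 2) * (lat_w // 2), channels * 4


def position_ids(height, width, first, vae_scale_factor=8):
    # One (stream, row, col) triple per packed token; `first` tags the stream.
    rows = height // (vae_scale_factor * 2)
    cols = width // (vae_scale_factor * 2)
    return [(first, r, c) for r in range(rows) for c in range(cols)]


noise = packed_shape(1, 16, 1024, 1024)   # (1, 4096, 64)
ref = packed_shape(1, 16, 768, 1024)      # (1, 3072, 64)
ids = position_ids(1024, 1024, 0) + position_ids(768, 1024, 1)

# Both streams go through the transformer together ...
joint_len = noise[1] + ref[1]
# ... but only the first noise[1] predicted tokens update the latents.
print(noise, ref, joint_len, len(ids), ids[noise[1]])
```

The packed channel count of 64 matches the default `in_channels` of `FluxTransformer2DModel`.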

FluxKontextTransformer2DModel

Bases: FluxTransformer2DModel

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    pooled_projections: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput

FluxPipeline

Bases: Module, FluxPipelineMixin, CFGParallelMixin, DiffusionPipelineProfilerMixin

current_timestep property

current_timestep

default_sample_size instance-attribute

default_sample_size = 128

device instance-attribute

device = get_local_device()

guidance_scale property

guidance_scale

interrupt property

interrupt

joint_attention_kwargs property

joint_attention_kwargs

num_timesteps property

num_timesteps

od_config instance-attribute

od_config = od_config

scheduler instance-attribute

scheduler = from_pretrained(
    model,
    subfolder="scheduler",
    local_files_only=local_files_only,
)

stage instance-attribute

stage = None

text_encoder instance-attribute

text_encoder = to(device)

text_encoder_2 instance-attribute

text_encoder_2 = to(device)

tokenizer instance-attribute

tokenizer = from_pretrained(
    model,
    subfolder="tokenizer",
    local_files_only=local_files_only,
)

tokenizer_2 instance-attribute

tokenizer_2 = from_pretrained(
    model,
    subfolder="tokenizer_2",
    local_files_only=local_files_only,
)

tokenizer_max_length instance-attribute

tokenizer_max_length = (
    model_max_length
    if hasattr(self, "tokenizer") and tokenizer is not None
    else 77
)

transformer instance-attribute

transformer = FluxTransformer2DModel(
    **transformer_kwargs,
    od_config=od_config,
    quant_config=quantization_config,
)

vae instance-attribute

vae = to(device)

vae_scale_factor instance-attribute

vae_scale_factor = (
    2 ** (len(block_out_channels) - 1)
    if getattr(self, "vae", None)
    else 8
)

weights_sources instance-attribute

weights_sources = [
    ComponentSource(
        model_or_path=model,
        subfolder="transformer",
        revision=None,
        prefix="transformer.",
        fall_back_to_pt=True,
    ),
    ComponentSource(
        model_or_path=model,
        subfolder="text_encoder_2",
        revision=None,
        prefix="text_encoder_2.",
        fall_back_to_pt=True,
    ),
]

check_cfg_parallel_validity

check_cfg_parallel_validity(
    true_cfg_scale: float, has_neg_prompt: bool
)
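In diffusers' `FluxPipeline`, true (negative-prompt) CFG runs only when `true_cfg_scale > 1` and a negative prompt is available. Otherwise only the conditional branch is computed, so there is nothing to split across CFG-parallel ranks.

The sketch below shows that gating. This page does not state whether vllm-omni's `check_cfg_parallel_validity` returns a flag, warns, or raises, so the helper is hypothetical.

```python
def cfg_parallel_useful(true_cfg_scale: float, has_neg_prompt: bool) -> bool:
    # Hypothetical helper: the unconditional branch exists only when both hold.
    return true_cfg_scale > 1.0 and has_neg_prompt


print(cfg_parallel_useful(1.0, True))   # False: default scale disables true CFG
print(cfg_parallel_useful(4.0, False))  # False: no negative prompt to run
print(cfg_parallel_useful(4.0, True))   # True
```

With the default `true_cfg_scale=1.0` from `forward`, true CFG stays off even if a negative prompt is passed.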

check_inputs

check_inputs(
    prompt,
    prompt_2,
    height,
    width,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
)
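The diffusers `FluxPipeline.check_inputs` enforces several rules that serve as a useful mental model for this method:

- `prompt` and `prompt_embeds` are mutually exclusive.
- `prompt_embeds` needs a matching `pooled_prompt_embeds`.
- `max_sequence_length` may not exceed 512.

In diffusers, sizes that are not multiples of `vae_scale_factor * 2` only produce a warning. The sketch below re-implements a subset of these rules in plain Python. It is an approximation, not the vllm-omni code.

```python
import warnings


def check_inputs_sketch(prompt=None, prompt_embeds=None,
                        pooled_prompt_embeds=None, height=1024, width=1024,
                        max_sequence_length=None, vae_scale_factor=8):
    multiple = vae_scale_factor * 2
    if height % multiple or width % multiple:
        # Non-fatal: the pipeline rounds the size instead of rejecting it.
        warnings.warn(f"height and width should be divisible by {multiple}; "
                      "they will be resized")
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Pass either `prompt` or `prompt_embeds`, not both.")
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError("`prompt_embeds` requires `pooled_prompt_embeds`.")
    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError("`max_sequence_length` cannot be greater than 512.")


check_inputs_sketch(prompt="a cat", max_sequence_length=512)  # passes
try:
    check_inputs_sketch(prompt="a cat", max_sequence_length=1024)
except ValueError as err:
    print(err)  # `max_sequence_length` cannot be greater than 512.
```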

diffuse

diffuse(
    prompt_embeds: Tensor,
    pooled_prompt_embeds: Tensor,
    negative_prompt_embeds: Tensor,
    negative_pooled_prompt_embeds: Tensor,
    latents: Tensor,
    latent_image_ids: Tensor,
    text_ids: Tensor,
    negative_text_ids: Tensor,
    timesteps: Tensor,
    do_true_cfg: bool,
    guidance: Tensor,
    true_cfg_scale: float,
    cfg_normalize: bool = False,
) -> Tensor

Diffusion loop with optional image conditioning.
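When `do_true_cfg` is set, each step evaluates the transformer on both the positive and the negative embeddings. The two predictions are then combined as `neg + scale * (pos - neg)`. This is separate from `guidance`, which guidance-distilled FLUX models receive as an embedding input instead of a second forward pass.

The meaning of `cfg_normalize` is an assumption here. The sketch follows the rescaling used by some other diffusion pipelines (for example Qwen-Image in diffusers), which rescales the combined prediction to the norm of the conditional one. The sketch operates on a single token vector in pure Python.

```python
import math


def combine_cfg(pos, neg, scale, normalize=False):
    # Classifier-free guidance: move from the negative toward the positive branch.
    comb = [n + scale * (p - n) for p, n in zip(pos, neg)]
    if normalize:
        # Assumed cfg_normalize behaviour: keep the conditional prediction's norm.
        pos_norm = math.sqrt(sum(p * p for p in pos))
        comb_norm = math.sqrt(sum(c * c for c in comb))
        if comb_norm > 0:
            comb = [c * pos_norm / comb_norm for c in comb]
    return comb


print(combine_cfg([1.0, 0.0], [0.0, 0.0], 4.0))        # [4.0, 0.0]
print(combine_cfg([1.0, 0.0], [0.0, 0.0], 4.0, True))  # [1.0, 0.0]
```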

encode_prompt

encode_prompt(
    prompt: str | list[str],
    prompt_2: str | list[str],
    num_images_per_prompt: int = 1,
    prompt_embeds: FloatTensor | None = None,
    pooled_prompt_embeds: FloatTensor | None = None,
    max_sequence_length: int = 512,
)

Parameters:

prompt (`str` or `List[str]`, required)
    The prompt or prompts to be encoded.
prompt_2 (`str` or `List[str]`, required)
    The prompt or prompts to be sent to tokenizer_2 and text_encoder_2. If None, prompt is used for both text encoders.
num_images_per_prompt (`int`, default 1)
    The number of images to generate per prompt.
prompt_embeds (`torch.FloatTensor`, optional, default None)
    Pre-generated text embeddings, useful for tweaking text inputs (e.g. prompt weighting). If not provided, text embeddings are generated from the prompt argument.
pooled_prompt_embeds (`torch.FloatTensor`, optional, default None)
    Pre-generated pooled text embeddings, useful for tweaking text inputs (e.g. prompt weighting). If not provided, pooled text embeddings are generated from the prompt argument.
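The diffusers Flux pipelines follow these conventions:

- `prompt_2` falls back to `prompt`.
- The pooled embedding comes from the CLIP encoder (`text_encoder`, fed `prompt`).
- The sequence embedding comes from the T5 encoder (`text_encoder_2`, fed `prompt_2`).
- Both embeddings are repeated `num_images_per_prompt` times.

The shape-level sketch below assumes those conventions plus the default `FluxTransformer2DModel` dimensions (`joint_attention_dim=4096`, `pooled_projection_dim=768`). The function is hypothetical.

```python
def encode_prompt_shapes(prompt, prompt_2=None, num_images_per_prompt=1,
                         max_sequence_length=512, t5_dim=4096, clip_dim=768):
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    prompt_2 = prompt_2 or prompt  # fall back to the first prompt
    prompts_2 = [prompt_2] if isinstance(prompt_2, str) else list(prompt_2)
    if len(prompts) != len(prompts_2):
        raise ValueError("prompt and prompt_2 must have the same length")

    batch = len(prompts) * num_images_per_prompt
    # Copies are laid out back to back: [p0, p0, p1, p1] for 2 images each.
    t5_inputs = [p for p in prompts_2 for _ in range(num_images_per_prompt)]
    return {
        "prompt_embeds": (batch, max_sequence_length, t5_dim),  # T5 tokens
        "pooled_prompt_embeds": (batch, clip_dim),              # CLIP pooled
        "text_ids": (max_sequence_length, 3),                   # all zeros
        "t5_inputs": t5_inputs,
    }


shapes = encode_prompt_shapes(["a cat", "a dog"], num_images_per_prompt=2)
print(shapes["prompt_embeds"])  # (4, 512, 4096)
print(shapes["t5_inputs"])      # ['a cat', 'a cat', 'a dog', 'a dog']
```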

forward

forward(
    req: OmniDiffusionRequest,
    prompt: str | list[str] | None = None,
    prompt_2: str | list[str] | None = None,
    negative_prompt: str | list[str] | None = None,
    negative_prompt_2: str | list[str] | None = None,
    true_cfg_scale: float = 1.0,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 28,
    sigmas: list[float] | None = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: int = 1,
    generator: Generator | list[Generator] | None = None,
    latents: FloatTensor | None = None,
    prompt_embeds: FloatTensor | None = None,
    pooled_prompt_embeds: FloatTensor | None = None,
    negative_prompt_embeds: FloatTensor | None = None,
    negative_pooled_prompt_embeds: FloatTensor
    | None = None,
    output_type: str | None = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: dict[str, Any] | None = None,
    callback_on_step_end_tensor_inputs: list[str] = [
        "latents"
    ],
    max_sequence_length: int = 512,
)

Forward pass for the Flux pipeline.

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

prepare_latents

prepare_latents(
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
)

prepare_timesteps

prepare_timesteps(
    num_inference_steps, sigmas, image_seq_len
)
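In diffusers' `FluxPipeline`, `prepare_timesteps` proceeds as follows:

1. The default sigmas are evenly spaced from 1.0 down to `1 / num_inference_steps`.
2. A shift `mu` is interpolated linearly from `image_seq_len`, the number of packed latent tokens.
3. The scheduler applies an exponential time shift using `mu`.

The sketch below reproduces that arithmetic. It assumes the FLUX.1 scheduler-config values `base_image_seq_len=256`, `max_image_seq_len=4096`, `base_shift=0.5`, and `max_shift=1.15`. The real values are read from the scheduler config, and diffusers also appends a terminal sigma of 0.

```python
import math


def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    # Linear interpolation of mu between (base_seq_len, base_shift) and
    # (max_seq_len, max_shift): larger images get a larger shift.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


def shifted_sigmas(num_inference_steps, image_seq_len, sigmas=None):
    n = num_inference_steps
    if sigmas is None:
        # Default schedule, like np.linspace(1.0, 1 / n, n).
        step = (1.0 - 1.0 / n) / (n - 1) if n > 1 else 0.0
        sigmas = [1.0 - i * step for i in range(n)]
    mu = calculate_shift(image_seq_len)
    # Exponential time shift: pushes sigmas toward 1 for positive mu.
    return [math.exp(mu) / (math.exp(mu) + (1.0 / s - 1.0)) for s in sigmas]


seq_len = (1024 // 16) * (1024 // 16)      # 4096 packed tokens at 1024x1024
print(round(calculate_shift(seq_len), 4))  # 1.15
print([round(s, 3) for s in shifted_sigmas(4, seq_len)])
```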

FluxTransformer2DModel

Bases: Module

The Transformer model introduced in Flux.

Parameters:

od_config (`OmniDiffusionConfig`, default None)
    The configuration for the model.
patch_size (`int`, default 1)
    Patch size used to split the input data into small patches.
in_channels (`int`, default 64)
    The number of input channels.
out_channels (`int`, optional, default None)
    The number of output channels. If not specified, it defaults to in_channels.
num_layers (`int`, default 19)
    The number of dual-stream DiT blocks.
num_single_layers (`int`, default 38)
    The number of single-stream DiT blocks.
attention_head_dim (`int`, default 128)
    The number of dimensions per attention head.
num_attention_heads (`int`, default 24)
    The number of attention heads.
joint_attention_dim (`int`, default 4096)
    The dimension used for joint attention (the embedding/channel dimension of encoder_hidden_states).
pooled_projection_dim (`int`, default 768)
    The dimension of the pooled projection.
guidance_embeds (`bool`, default True)
    Whether to use guidance embeddings, as in the guidance-distilled variant of the model.
axes_dims_rope (`Tuple[int]`, default (16, 56, 56))
    The dimensions to use for the rotary positional embeddings.
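The defaults above fix the model width. The arithmetic below follows the `inner_dim`, `out_channels`, and `proj_out` attributes documented for this class.

The final check assumes that the rotary dimensions of the three position-ID axes together cover one attention head, which is how FLUX applies rotary embeddings per head. The default `in_channels` of 64 also equals the 16 VAE latent channels folded into 2x2 patches.

```python
# Default hyperparameters taken from the parameter list above.
num_attention_heads = 24
attention_head_dim = 128
axes_dims_rope = (16, 56, 56)
in_channels = 64
patch_size = 1
out_channels = None

inner_dim = num_attention_heads * attention_head_dim        # 3072
out_channels = out_channels or in_channels                  # 64
proj_out_features = patch_size * patch_size * out_channels  # 64

# Rotary dims are split across the three ID axes and must span one head.
assert sum(axes_dims_rope) == attention_head_dim

print(inner_dim, out_channels, proj_out_features)  # 3072 64 64
```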

context_embedder instance-attribute

context_embedder = Linear(joint_attention_dim, inner_dim)

guidance_embeds instance-attribute

guidance_embeds = guidance_embeds

in_channels instance-attribute

in_channels = in_channels

inner_dim instance-attribute

inner_dim = num_attention_heads * attention_head_dim

norm_out instance-attribute

norm_out = AdaLayerNormContinuous(
    inner_dim,
    inner_dim,
    elementwise_affine=False,
    eps=1e-06,
    quant_config=_safe_quant_config(quant_config),
    prefix="norm_out",
)

out_channels instance-attribute

out_channels = out_channels or in_channels

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "to_qkv": ["to_q", "to_k", "to_v"],
    "add_kv_proj": [
        "add_q_proj",
        "add_k_proj",
        "add_v_proj",
    ],
}
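`packed_modules_mapping` indicates that the separate query, key, and value projections found in checkpoints are stored in fused `to_qkv` and `add_kv_proj` layers. Below is a minimal sketch of how a loader could route a checkpoint name to the fused parameter plus a shard index, in the spirit of vLLM's stacked-parameter mapping. `resolve_packed_name` is hypothetical, and the real `load_weights` may work differently.

```python
from __future__ import annotations

PACKED_MODULES_MAPPING = {
    "to_qkv": ["to_q", "to_k", "to_v"],
    "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
}


def resolve_packed_name(name: str) -> tuple[str, int | None]:
    # Return (fused parameter name, shard index), or (name, None) if unpacked.
    # Matching whole dotted components keeps "to_q" from matching "add_q_proj".
    parts = name.split(".")
    for packed, shards in PACKED_MODULES_MAPPING.items():
        for idx, shard in enumerate(shards):
            if shard in parts:
                pos = parts.index(shard)
                fused = parts[:pos] + [packed] + parts[pos + 1:]
                return ".".join(fused), idx
    return name, None


print(resolve_packed_name("transformer_blocks.0.attn.to_k.weight"))
# ('transformer_blocks.0.attn.to_qkv.weight', 1)
print(resolve_packed_name("transformer_blocks.0.attn.add_v_proj.bias"))
# ('transformer_blocks.0.attn.add_kv_proj.bias', 2)
print(resolve_packed_name("proj_out.weight"))  # ('proj_out.weight', None)
```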

parallel_config instance-attribute

parallel_config = parallel_config

pos_embed instance-attribute

pos_embed = FluxPosEmbed(
    theta=theta, axes_dim=axes_dims_rope
)
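`FluxPosEmbed` applies rotary position embeddings separately to each of the three position-ID axes, using `axes_dims_rope` channels per axis. Below is a pure-Python sketch of the rotation angles for one token. It assumes the standard RoPE frequency formula `theta ** (-k / d)` for even `k`, and `theta = 10000`, which is the value diffusers uses for FLUX. The `theta` passed here is not shown on this page.

```python
def rope_angles(position_ids, axes_dims=(16, 56, 56), theta=10000.0):
    # One angle per rotated channel pair; the three axes are concatenated.
    angles = []
    for pos, dim in zip(position_ids, axes_dims):
        for k in range(0, dim, 2):
            freq = theta ** (-k / dim)  # equals 1 / theta**(k / dim)
            angles.append(pos * freq)
    return angles


a = rope_angles((0, 3, 5))  # (stream, row, col) for one image token
print(len(a))   # 64 pairs, i.e. attention_head_dim // 2
print(a[:2])    # [0.0, 0.0]: the stream axis is 0 for this token
```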

proj_out instance-attribute

proj_out = Linear(
    inner_dim,
    patch_size * patch_size * out_channels,
    bias=True,
)

single_transformer_blocks instance-attribute

single_transformer_blocks = ModuleList(
    [
        FluxSingleTransformerBlock(
            dim=inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            quant_config=quant_config,
            prefix=f"single_transformer_blocks.{i}",
        )
        for i in range(num_single_layers)
    ]
)

time_text_embed instance-attribute

time_text_embed = text_time_guidance_cls(
    embedding_dim=inner_dim,
    pooled_projection_dim=pooled_projection_dim,
)

transformer_blocks instance-attribute

transformer_blocks = ModuleList(
    [
        FluxTransformerBlock(
            dim=inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            quant_config=_safe_quant_config(quant_config),
            prefix=f"transformer_blocks.{i}",
        )
        for i in range(num_layers)
    ]
)

x_embedder instance-attribute

x_embedder = Linear(in_channels, inner_dim)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    pooled_projections: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput

The FluxTransformer2DModel forward method.

Parameters:

hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`, required)
    Input hidden states.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`, default None)
    Conditional embeddings computed from the input conditions, such as prompts.
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`, default None)
    Embeddings projected from the embeddings of the input conditions.
timestep (`torch.LongTensor`, default None)
    Indicates the current denoising step.
img_ids (`torch.Tensor`, default None)
    The position IDs for image tokens.
txt_ids (`torch.Tensor`, default None)
    The position IDs for text tokens.
guidance (`torch.Tensor`, default None)
    Guidance embeddings for the guidance-distilled variant of the model.
joint_attention_kwargs (`dict`, optional, default None)
    A kwargs dictionary that, if specified, is passed along to the AttentionProcessor as defined under self.processor in diffusers.models.attention_processor.
return_dict (`bool`, optional, default True)
    Whether to return a Transformer2DModelOutput instead of a plain tuple.

Returns:

Tensor | Transformer2DModelOutput
    If return_dict is True, a Transformer2DModelOutput is returned; otherwise, a tuple whose first element is the sample tensor.

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

get_flux_kontext_post_process_func

get_flux_kontext_post_process_func(
    od_config: OmniDiffusionConfig,
) -> Callable

Returns the post-processing function for the FLUX.1-Kontext pipeline.

get_flux_post_process_func

get_flux_post_process_func(od_config: OmniDiffusionConfig)
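This page does not describe the post-processing functions themselves. For orientation, the diffusers Flux pipelines produce an image in four stages:

1. Unpack the latents.
2. Undo the VAE scaling with `latents / scaling_factor + shift_factor`.
3. Run the VAE decoder.
4. Map the decoded values from [-1, 1] to 8-bit pixels.

The scalar sketch below covers only the last stage, which `VaeImageProcessor.postprocess` performs for `output_type="pil"`. It assumes the vllm-omni function delegates to that processor and is an illustration, not the vllm-omni function.

```python
def to_uint8(value: float) -> int:
    # Denormalize [-1, 1] -> [0, 1], clamp, then scale to 0..255 and round.
    x = min(max(value / 2 + 0.5, 0.0), 1.0)
    return round(x * 255)


print([to_uint8(v) for v in (-1.0, 0.0, 1.0, 1.5)])  # [0, 128, 255, 255]
```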