Skip to content

vllm_omni.diffusion.models.flux2

Flux2 diffusion model components.

Modules:

| Name | Description |
| --- | --- |
| flux2_transformer | (no description) |
| pipeline_flux2 | (no description) |

Flux2Pipeline

Bases: Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin

Flux2 pipeline for text-to-image generation, with optional image input.

attention_kwargs property

attention_kwargs

current_timestep property

current_timestep

default_sample_size instance-attribute

default_sample_size = 128

guidance_scale property

guidance_scale

image_processor instance-attribute

image_processor = Flux2ImageProcessor(
    vae_scale_factor=vae_scale_factor * 2
)

interrupt property

interrupt

num_timesteps property

num_timesteps

od_config instance-attribute

od_config = od_config

scheduler instance-attribute

scheduler = from_pretrained(
    model,
    subfolder="scheduler",
    local_files_only=local_files_only,
)

support_image_input class-attribute instance-attribute

support_image_input = True

system_message instance-attribute

system_message = SYSTEM_MESSAGE

system_message_upsampling_i2i instance-attribute

system_message_upsampling_i2i = (
    SYSTEM_MESSAGE_UPSAMPLING_I2I
)

system_message_upsampling_t2i instance-attribute

system_message_upsampling_t2i = (
    SYSTEM_MESSAGE_UPSAMPLING_T2I
)

text_encoder instance-attribute

text_encoder = to(_execution_device)

tokenizer instance-attribute

tokenizer = from_pretrained(
    model,
    subfolder="tokenizer",
    local_files_only=local_files_only,
)

tokenizer_max_length instance-attribute

tokenizer_max_length = 512

transformer instance-attribute

transformer = Flux2Transformer2DModel(
    quant_config=quantization_config,
    od_config=od_config,
    **transformer_kwargs,
)

upsampling_max_image_size instance-attribute

upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

vae instance-attribute

vae = to(_execution_device)

vae_scale_factor instance-attribute

vae_scale_factor = (
    2 ** (len(block_out_channels) - 1)
    if getattr(self, "vae", None)
    else 8
)

weights_sources instance-attribute

weights_sources = [
    ComponentSource(
        model_or_path=model,
        subfolder="transformer",
        revision=None,
        prefix="transformer.",
        fall_back_to_pt=True,
    )
]

check_cfg_parallel_validity

check_cfg_parallel_validity(
    true_cfg_scale: float, has_neg_prompt: bool
)

check_inputs

check_inputs(
    prompt,
    height,
    width,
    prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
)

encode_prompt

encode_prompt(
    prompt: str | list[str],
    device: device | None = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Tensor | None = None,
    max_sequence_length: int = 512,
    text_encoder_out_layers: tuple[int, ...] = (10, 20, 30),
)

forward

forward(
    req: OmniDiffusionRequest,
    image: Image | list[Image] | None = None,
    prompt: str | list[str] | None = None,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 50,
    sigmas: list[float] | None = None,
    guidance_scale: float | None = 4.0,
    num_images_per_prompt: int = 1,
    generator: Generator | list[Generator] | None = None,
    latents: Tensor | None = None,
    prompt_embeds: Tensor | None = None,
    negative_prompt_embeds: Tensor | None = None,
    output_type: str | None = "pil",
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
    callback_on_step_end: Callable[[int, int, dict], None]
    | None = None,
    callback_on_step_end_tensor_inputs: list[str] = [
        "latents"
    ],
    max_sequence_length: int = 512,
    text_encoder_out_layers: tuple[int, ...] = (10, 20, 30),
    caption_upsample_temperature: float = None,
) -> DiffusionOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

prepare_image_latents

prepare_image_latents(
    images: list[Tensor],
    batch_size,
    generator: Generator,
    device,
    dtype,
)

prepare_latents

prepare_latents(
    batch_size,
    num_latents_channels,
    height,
    width,
    dtype,
    device,
    generator: Generator,
    latents: Tensor | None = None,
)

upsample_prompt

upsample_prompt(
    prompt: str | list[str],
    images: list[Image] | list[list[Image]] = None,
    temperature: float = 0.15,
    device: device = None,
) -> list[str]

Flux2Transformer2DModel

Bases: Module

The Transformer model introduced in Flux 2.

Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.

config instance-attribute

config = SimpleNamespace(
    patch_size=patch_size,
    in_channels=in_channels,
    out_channels=out_channels,
    num_layers=num_layers,
    num_single_layers=num_single_layers,
    attention_head_dim=attention_head_dim,
    num_attention_heads=num_attention_heads,
    joint_attention_dim=joint_attention_dim,
    timestep_guidance_channels=timestep_guidance_channels,
    mlp_ratio=mlp_ratio,
    axes_dims_rope=axes_dims_rope,
    rope_theta=rope_theta,
    eps=eps,
    guidance_embeds=guidance_embeds,
)

context_embedder instance-attribute

context_embedder = Linear(
    joint_attention_dim, inner_dim, bias=False
)

double_stream_modulation_img instance-attribute

double_stream_modulation_img = Flux2Modulation(
    inner_dim, mod_param_sets=2, bias=False
)

double_stream_modulation_txt instance-attribute

double_stream_modulation_txt = Flux2Modulation(
    inner_dim, mod_param_sets=2, bias=False
)

dtype property

dtype: dtype

guidance_embeds instance-attribute

guidance_embeds = guidance_embeds

inner_dim instance-attribute

inner_dim = num_attention_heads * attention_head_dim

norm_out instance-attribute

norm_out = AdaLayerNormContinuous(
    inner_dim,
    inner_dim,
    elementwise_affine=False,
    eps=eps,
    bias=False,
)

out_channels instance-attribute

out_channels = out_channels or in_channels

parallel_config instance-attribute

parallel_config = parallel_config

pos_embed instance-attribute

pos_embed = Flux2PosEmbed(
    theta=rope_theta, axes_dim=axes_dims_rope
)

proj_out instance-attribute

proj_out = Linear(
    inner_dim,
    patch_size * patch_size * out_channels,
    bias=False,
)

rope_prepare instance-attribute

rope_prepare = Flux2RopePrepare(pos_embed)

single_stream_modulation instance-attribute

single_stream_modulation = Flux2Modulation(
    inner_dim, mod_param_sets=1, bias=False
)

single_transformer_blocks instance-attribute

single_transformer_blocks = ModuleList(
    [
        (
            Flux2SingleTransformerBlock(
                parallel_config=parallel_config,
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_ratio=mlp_ratio,
                eps=eps,
                bias=False,
                quant_config=quant_config,
                prefix=f"single_transformer_blocks.{i}",
            )
        )
        for i in (range(num_single_layers))
    ]
)

stacked_params_mapping instance-attribute

stacked_params_mapping = None

time_guidance_embed instance-attribute

time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
    in_channels=timestep_guidance_channels,
    embedding_dim=inner_dim,
    bias=False,
    guidance_embeds=guidance_embeds,
)

transformer_blocks instance-attribute

transformer_blocks = ModuleList(
    [
        (
            Flux2TransformerBlock(
                parallel_config=parallel_config,
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_ratio=mlp_ratio,
                eps=eps,
                bias=False,
                quant_config=quant_config,
                prefix=f"transformer_blocks.{i}",
            )
        )
        for i in (range(num_layers))
    ]
)

x_embedder instance-attribute

x_embedder = Linear(in_channels, inner_dim, bias=False)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

get_flux2_post_process_func

get_flux2_post_process_func(od_config: OmniDiffusionConfig)