Skip to content

vllm_omni.diffusion.models.flux.pipeline_flux_kontext

logger module-attribute

logger = init_logger(__name__)

FluxKontextPipeline

Bases: Module, FluxPipelineMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin

FLUX.1-Kontext pipeline for text-guided image editing.

default_sample_size instance-attribute

default_sample_size = 128

guidance_scale property

guidance_scale

image_processor instance-attribute

image_processor = VaeImageProcessor(
    vae_scale_factor=vae_scale_factor * 2
)

interrupt property

interrupt

joint_attention_kwargs property

joint_attention_kwargs

latent_channels instance-attribute

latent_channels = (
    latent_channels if hasattr(vae, "config") else 16
)

num_timesteps property

num_timesteps

od_config instance-attribute

od_config = od_config

scheduler instance-attribute

scheduler = from_pretrained(
    model,
    subfolder="scheduler",
    local_files_only=local_files_only,
)

support_image_input class-attribute instance-attribute

support_image_input = True

text_encoder instance-attribute

text_encoder = to(_execution_device)

text_encoder_2 instance-attribute

text_encoder_2 = to(_execution_device)

tokenizer instance-attribute

tokenizer = from_pretrained(
    model,
    subfolder="tokenizer",
    local_files_only=local_files_only,
)

tokenizer_2 instance-attribute

tokenizer_2 = from_pretrained(
    model,
    subfolder="tokenizer_2",
    local_files_only=local_files_only,
)

tokenizer_max_length instance-attribute

tokenizer_max_length = (
    model_max_length
    if hasattr(self, "tokenizer") and tokenizer is not None
    else 77
)

transformer instance-attribute

transformer = FluxKontextTransformer2DModel(
    **transformer_kwargs
)

vae instance-attribute

vae = to(_execution_device)

vae_scale_factor instance-attribute

vae_scale_factor = (
    2 ** (len(block_out_channels) - 1)
    if getattr(self, "vae", None)
    else 8
)

weights_sources instance-attribute

weights_sources = [
    ComponentSource(
        model_or_path=model,
        subfolder="transformer",
        revision=None,
        prefix="transformer.",
        fall_back_to_pt=True,
    )
]

check_inputs

check_inputs(
    prompt,
    prompt_2=None,
    height=None,
    width=None,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    max_sequence_length=None,
)

encode_prompt

encode_prompt(
    prompt: str | list[str],
    prompt_2: str | list[str] | None = None,
    device: device | None = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: FloatTensor | None = None,
    pooled_prompt_embeds: FloatTensor | None = None,
    max_sequence_length: int = 512,
)

forward

forward(
    req: OmniDiffusionRequest,
    image: Image | list[Image] | None = None,
    prompt: str | list[str] | None = None,
    prompt_2: str | list[str] | None = None,
    negative_prompt: str | list[str] | None = None,
    negative_prompt_2: str | list[str] | None = None,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    true_cfg_scale: float = 1.0,
    num_images_per_prompt: int = 1,
    generator: Generator | list[Generator] | None = None,
    latents: Tensor | None = None,
    prompt_embeds: Tensor | None = None,
    pooled_prompt_embeds: Tensor | None = None,
    negative_prompt_embeds: Tensor | None = None,
    negative_pooled_prompt_embeds: Tensor | None = None,
    output_type: str | None = "pil",
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
    callback_on_step_end: Callable[[int, int, dict], None]
    | None = None,
    callback_on_step_end_tensor_inputs: list[str] = [
        "latents"
    ],
    max_sequence_length: int = 512,
    sigmas: list[float] | None = None,
) -> DiffusionOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

prepare_latents

prepare_latents(
    image: Tensor | None,
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: dtype,
    device: device,
    generator: Generator | list[Generator] | None = None,
    latents: Tensor | None = None,
)

get_flux_kontext_post_process_func

get_flux_kontext_post_process_func(
    od_config: OmniDiffusionConfig,
) -> Callable

Return the post-processing function for the FLUX.1-Kontext pipeline.