Skip to content

vllm_omni.diffusion.models.longcat_image

Modules:

- `longcat_image_transformer`
- `pipeline_longcat_image`
- `pipeline_longcat_image_edit`

LongCatImagePipeline

Bases: Module, CFGParallelMixin, DiffusionPipelineProfilerMixin

default_sample_size instance-attribute

default_sample_size = 128

device instance-attribute

device = get_local_device()

do_classifier_free_guidance property

do_classifier_free_guidance

od_config instance-attribute

od_config = od_config

prompt_template_encode_prefix instance-attribute

prompt_template_encode_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"

prompt_template_encode_suffix instance-attribute

prompt_template_encode_suffix = (
    "<|im_end|>\n<|im_start|>assistant\n"
)

scheduler instance-attribute

scheduler = from_pretrained(
    model,
    subfolder="scheduler",
    local_files_only=local_files_only,
)

text_encoder instance-attribute

text_encoder = from_pretrained_with_prefetch(
    from_pretrained,
    model,
    subfolder="text_encoder",
    prefetch_list=longcat_subfolders,
    local_files_only=local_files_only,
)

text_processor instance-attribute

text_processor = from_pretrained(
    model,
    subfolder="tokenizer",
    local_files_only=local_files_only,
)

tokenizer instance-attribute

tokenizer = from_pretrained(
    model,
    subfolder="tokenizer",
    local_files_only=local_files_only,
)

tokenizer_max_length instance-attribute

tokenizer_max_length = 512

transformer instance-attribute

transformer = LongCatImageTransformer2DModel(
    od_config=od_config
)

vae instance-attribute

vae = to(device)

vae_scale_factor instance-attribute

vae_scale_factor = (
    2 ** (len(block_out_channels) - 1)
    if getattr(self, "vae", None)
    else 8
)

weights_sources instance-attribute

weights_sources = [
    ComponentSource(
        model_or_path=model,
        subfolder="transformer",
        revision=None,
        prefix="transformer.",
        fall_back_to_pt=True,
    )
]

cfg_normalize_function

cfg_normalize_function(
    noise_pred, comb_pred, cfg_renorm_min=0.0
)

Normalize the combined noise prediction.

check_inputs

check_inputs(
    prompt,
    height,
    width,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
)

encode_prompt

encode_prompt(
    prompt: str | list[str] | None = None,
    num_images_per_prompt: int | None = 1,
    prompt_embeds: Tensor | None = None,
) -> tuple[Tensor, Tensor]

forward

forward(
    req: OmniDiffusionRequest,
    prompt: str | list[str] | None = None,
    negative_prompt: str | list[str] | None = None,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 50,
    sigmas: list[float] | None = None,
    guidance_scale: float = 4.5,
    num_images_per_prompt: int = 1,
    generator: Generator | list[Generator] | None = None,
    latents: FloatTensor | None = None,
    prompt_embeds: Tensor | None = None,
    negative_prompt_embeds: Tensor | None = None,
    output_type: str | None = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: dict[str, Any] | None = None,
    enable_cfg_renorm: bool | None = True,
    cfg_renorm_min: float | None = 0.0,
    enable_prompt_rewrite: bool | None = True,
) -> DiffusionOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights using AutoWeightsLoader for vLLM integration.

prepare_latents

prepare_latents(
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
)

rewire_prompt

rewire_prompt(prompt, device)

LongCatImageTransformer2DModel

Bases: Module

Transformer model for LongCat-Image, based on the architecture introduced in Flux.

Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.

context_embedder instance-attribute

context_embedder = Linear(joint_attention_dim, inner_dim)

gradient_checkpointing instance-attribute

gradient_checkpointing = False

inner_dim instance-attribute

inner_dim = num_attention_heads * attention_head_dim

norm_out instance-attribute

norm_out = AdaLayerNormContinuous(
    inner_dim,
    inner_dim,
    elementwise_affine=False,
    eps=1e-06,
)

out_channels instance-attribute

out_channels = in_channels

parallel_config instance-attribute

parallel_config = parallel_config

pooled_projection_dim instance-attribute

pooled_projection_dim = pooled_projection_dim

pos_embed instance-attribute

pos_embed = LongCatImagePosEmbed(
    theta=10000, axes_dim=axes_dims_rope
)

proj_out instance-attribute

proj_out = Linear(
    inner_dim,
    patch_size * patch_size * out_channels,
    bias=True,
)

rope_preparer instance-attribute

rope_preparer = RoPEPreparer(pos_embed)

single_transformer_blocks instance-attribute

single_transformer_blocks = ModuleList(
    [
        (
            LongCatImageSingleTransformerBlock(
                parallel_config=parallel_config,
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
            )
        )
        for i in (range(num_single_layers))
    ]
)

time_embed instance-attribute

time_embed = LongCatImageTimestepEmbeddings(
    embedding_dim=inner_dim
)

transformer_blocks instance-attribute

transformer_blocks = ModuleList(
    [
        (
            LongCatImageTransformerBlock(
                parallel_config=parallel_config,
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
            )
        )
        for i in (range(num_layers))
    ]
)

use_checkpoint instance-attribute

use_checkpoint = [True] * num_layers

use_single_checkpoint instance-attribute

use_single_checkpoint = [True] * num_single_layers

x_embedder instance-attribute

x_embedder = Linear(in_channels, inner_dim)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor = None,
    return_dict: bool = True,
) -> FloatTensor | Transformer2DModelOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

get_longcat_image_post_process_func

get_longcat_image_post_process_func(
    od_config: OmniDiffusionConfig,
)