Skip to content

vllm_omni.diffusion.model_loader.checkpoint_adapters.modelopt

`DEFAULT_PACKED_MODULES_MAPPING` (module attribute)

# Default table from a packed (fused) module name to the tuple of
# individual module names it packs. Keys and tuple order are significant
# to consumers of this mapping (not visible here) and must not change.
DEFAULT_PACKED_MODULES_MAPPING = dict(
    to_qkv=("to_q", "to_k", "to_v"),
    add_kv_proj=("add_q_proj", "add_k_proj", "add_v_proj"),
    w13=("w1", "w3"),
)

`FP8_DTYPES` (module attribute)

# Every 8-bit floating-point dtype that the installed torch build actually
# exposes. Each name is looked up with getattr(..., None), so dtypes missing
# from older torch versions are silently skipped rather than raising.
FP8_DTYPES = tuple(
    candidate
    for candidate in (
        getattr(torch, attr_name, None)
        for attr_name in (
            "float8_e4m3fn",
            "float8_e5m2",
            "float8_e4m3fnuz",
            "float8_e5m2fnuz",
        )
    )
    if candidate is not None
)

`MODEL_OPT_SCALE_SUFFIXES` (module attribute)

# Tensor-name suffixes recognised as ModelOpt scale entries. Kept as a tuple
# so it can be passed straight to str.endswith(); presumably used to spot
# scale tensors in checkpoint weight names (confirm against the adapters).
MODEL_OPT_SCALE_SUFFIXES = (".input_scale", ".weight_scale",
                            ".weight_scale_2", ".weight_scale_inv")

`logger` (module attribute)

# Module-level logger keyed by this module's __name__. init_logger is
# defined outside this chunk (presumably the project's logging helper).
logger = init_logger(__name__)

ModelOptFp8CheckpointAdapter

adapt

adapt(
    weights: Iterable[tuple[str, Tensor]],
) -> Generator[tuple[str, Tensor], None, None]

`is_compatible` (classmethod)

is_compatible(
    source: object,
    quant_config: object | None,
    use_safetensors: bool,
) -> bool

ModelOptMixedPrecisionCheckpointAdapter

ModelOptNvFp4CheckpointAdapter