vllm_omni.diffusion.models.utils

Style module-attribute

Style = Literal[
    "colwise",
    "colwise_rep",
    "rowwise",
    "rowwise_rep",
    "replicate",
]
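
Because Style is a Literal alias, the permitted strings can be recovered at runtime with typing.get_args. The sketch below redeclares the alias locally so that it runs without vllm_omni installed; is_valid_style is a hypothetical helper and not part of this module.

```python
from typing import Literal, get_args

# Local copy of the alias shown above, so the example is self-contained.
Style = Literal[
    "colwise",
    "colwise_rep",
    "rowwise",
    "rowwise_rep",
    "replicate",
]


def is_valid_style(value: str) -> bool:
    """Hypothetical helper: True if `value` is one of the Style literals."""
    return value in get_args(Style)


print(is_valid_style("colwise"))   # True
print(is_valid_style("diagonal"))  # False
```

A check like this can be useful when style names come from a user-supplied config, since Literal types are not enforced at runtime.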

create_transformers_model

create_transformers_model(
    auto_cls: _BaseAutoModelClass,
    od_config: OmniDiffusionConfig,
    hf_config: PretrainedConfig,
    dtype: dtype | None = None,
    device: device | None = None,
) -> PreTrainedModel

Create a HuggingFace model using the given auto class and model name.

init_parameters

init_parameters(
    module: Module,
    dtype: dtype | None,
    device: device | None = None,
)

recursive_replace_linear

recursive_replace_linear(
    model: Module, od_config: OmniDiffusionConfig
)

Recursively replace modules in the model as needed. Currently, this replaces:

- nn.Linear with vLLM's tensor parallel linear classes
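
The general shape of such a recursive replacement can be sketched without PyTorch. Everything below is a stand-in: Linear, ParallelLinear, Block and replace_linears are hypothetical names used only for illustration, and the actual function presumably walks a torch.nn.Module tree and consults od_config, which this page does not detail.

```python
from dataclasses import dataclass, field


@dataclass
class Linear:
    """Stand-in for torch.nn.Linear (assumption, not the real class)."""
    in_features: int
    out_features: int


@dataclass
class ParallelLinear:
    """Stand-in for a tensor parallel replacement layer."""
    in_features: int
    out_features: int
    prefix: str


@dataclass
class Block:
    """Stand-in for a container module holding named children."""
    children: dict = field(default_factory=dict)


def replace_linears(module: Block, prefix: str = "") -> int:
    """Swap every Linear under `module` in place; return how many were replaced."""
    replaced = 0
    # Iterate over a snapshot because the dict is modified during the walk.
    for name, child in list(module.children.items()):
        qualified = f"{prefix}.{name}" if prefix else name
        if isinstance(child, Linear):
            module.children[name] = ParallelLinear(
                child.in_features, child.out_features, qualified
            )
            replaced += 1
        elif isinstance(child, Block):
            # Recurse into containers, extending the dotted prefix.
            replaced += replace_linears(child, qualified)
    return replaced


model = Block({
    "attn": Block({"to_q": Linear(64, 64), "to_out": Linear(64, 64)}),
    "proj": Linear(64, 8),
})
count = replace_linears(model)
print(count)                                         # 3
print(model.children["attn"].children["to_q"].prefix)  # attn.to_q
```

The dotted prefix built during the walk plays the same role as the prefix keyword of replace_linear_class: it gives each replaced layer a unique qualified name.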

replace_linear_class

replace_linear_class(
    linear: Linear,
    style: Style = "replicate",
    quant_config: QuantizationConfig | None = None,
    *,
    prefix: str = "",
) -> (
    ColumnParallelLinear
    | RowParallelLinear
    | ReplicatedLinear
)

Replace nn.Linear with one of vLLM's tensor parallel linear classes.

Parameters:

Name          Type                        Description                                                Default
linear        Linear                      nn.Linear to be replaced.                                  required
style         Style                       Tensor parallel style of the new linear, e.g. "colwise".   'replicate'
quant_config  QuantizationConfig | None   Quantization config for the new linear.                    None

Returns: The new linear.
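
To make the style names concrete, the sketch below computes the per-rank weight shape each style would imply. It assumes the usual tensor parallel convention: column parallel layers shard the output dimension, row parallel layers shard the input dimension, and replicated layers keep the full weight on every rank. It further assumes that the "_rep" variants change only how activations are gathered or split, not how the weight is sharded; this page does not confirm that. local_weight_shape is a hypothetical helper, not part of this module.

```python
def local_weight_shape(
    in_features: int, out_features: int, style: str, tp_size: int
) -> tuple[int, int]:
    """Return the (out, in) weight shape held by one tensor parallel rank.

    Assumption: "colwise*" shards out_features, "rowwise*" shards
    in_features, and "replicate" keeps the full weight.
    """
    if style in ("colwise", "colwise_rep"):
        if out_features % tp_size != 0:
            raise ValueError("out_features must be divisible by tp_size")
        return (out_features // tp_size, in_features)
    if style in ("rowwise", "rowwise_rep"):
        if in_features % tp_size != 0:
            raise ValueError("in_features must be divisible by tp_size")
        return (out_features, in_features // tp_size)
    if style == "replicate":
        return (out_features, in_features)
    raise ValueError(f"unknown style: {style!r}")


# A Linear with in_features=64 and out_features=128 on 4 ranks:
print(local_weight_shape(64, 128, "colwise", 4))    # (32, 64)
print(local_weight_shape(64, 128, "rowwise", 4))    # (128, 16)
print(local_weight_shape(64, 128, "replicate", 4))  # (128, 64)
```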