vllm_omni.diffusion.models.utils
Style module-attribute
Style = Literal[
"colwise",
"colwise_rep",
"rowwise",
"rowwise_rep",
"replicate",
]
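`Literal` only constrains static type checkers. At runtime nothing stops a misspelled string such as "colwize" from being passed. Below is a minimal sketch of a runtime check using `typing.get_args`. It assumes you want to validate user-supplied styles before handing them to `replace_linear_class`; `validate_style` is a hypothetical helper, not part of this module.

```python
from typing import Literal, get_args

# Same literal type as documented above.
Style = Literal[
    "colwise",
    "colwise_rep",
    "rowwise",
    "rowwise_rep",
    "replicate",
]


def validate_style(style: str) -> Style:
    """Hypothetical helper: raise ValueError if `style` is not a Style literal."""
    allowed = get_args(Style)  # tuple of the literal string values
    if style not in allowed:
        raise ValueError(f"Unsupported style {style!r}; expected one of {allowed}")
    return style  # type: ignore[return-value]
```

Deriving the allowed values from `get_args(Style)` keeps the check in sync with the type alias instead of duplicating the list.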
create_transformers_model
create_transformers_model(
auto_cls: _BaseAutoModelClass,
od_config: OmniDiffusionConfig,
hf_config: PretrainedConfig,
dtype: dtype | None = None,
device: device | None = None,
) -> PreTrainedModel
Create a HuggingFace model using the given auto class and model name.
init_parameters
recursive_replace_linear
recursive_replace_linear(
model: Module, od_config: OmniDiffusionConfig
)
Recursively replace modules in the model as needed. Currently, this replaces:
- nn.Linear with vLLM's tensor parallel linear classes
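The recursive walk can be sketched in plain Python. This is an illustration under stated assumptions, not the module's implementation: `Linear`, `ParallelLinear`, and `Module` are hypothetical stand-ins for `torch.nn.Linear`, a vLLM parallel linear, and `torch.nn.Module`. The `plan` dict that maps qualified names to a `Style` is also hypothetical, since how the real function picks a style from `od_config` is not documented here. With real PyTorch modules, the same loop would use `module.named_children()` and `setattr(module, name, new_layer)`.

```python
from dataclasses import dataclass, field


@dataclass
class Linear:
    """Hypothetical stand-in for torch.nn.Linear."""
    in_features: int
    out_features: int
    bias: bool = True


@dataclass
class ParallelLinear:
    """Hypothetical stand-in for a vLLM tensor parallel linear layer."""
    in_features: int
    out_features: int
    bias: bool
    style: str


@dataclass
class Module:
    """Hypothetical stand-in for torch.nn.Module: only named children."""
    children: dict = field(default_factory=dict)


def recursive_replace_linear_sketch(module: Module, plan: dict, prefix: str = "") -> None:
    """Depth-first walk that swaps every Linear in place.

    `plan` maps a fully qualified child name (e.g. "block.qkv") to a style;
    names missing from the plan fall back to "replicate".
    """
    for name, child in list(module.children.items()):
        qual_name = f"{prefix}.{name}" if prefix else name
        if isinstance(child, Linear):
            style = plan.get(qual_name, "replicate")
            module.children[name] = ParallelLinear(
                child.in_features, child.out_features, child.bias, style
            )
        elif isinstance(child, Module):
            # Recurse into submodules, extending the qualified name.
            recursive_replace_linear_sketch(child, plan, qual_name)


# Usage: a tiny two-level tree with one leaf not covered by the plan.
model = Module({
    "block": Module({"qkv": Linear(64, 192), "out": Linear(64, 64, bias=False)}),
    "head": Linear(64, 10),
})
recursive_replace_linear_sketch(model, {"block.qkv": "colwise", "block.out": "rowwise"})
```

Replacing children in place keeps the parent's attribute names unchanged, so checkpoint keys keyed by qualified name still line up after the swap.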
replace_linear_class
replace_linear_class(
linear: Linear,
style: Style = "replicate",
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
) -> (
ColumnParallelLinear
| RowParallelLinear
| ReplicatedLinear
)
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| linear | Linear | The nn.Linear module to replace. | required |
| style | Style | Tensor parallel style of the new linear, e.g. "colwise". | 'replicate' |
| quant_config | QuantizationConfig \| None | Quantization config for the new linear. | None |
Returns: The new linear layer (a ColumnParallelLinear, RowParallelLinear, or ReplicatedLinear).
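One plausible way to map a `Style` onto the three return types is a lookup table from style to class plus extra constructor arguments. The sketch below assumes this function mirrors the `replace_linear_class` helper in upstream vLLM's Transformers backend. To our knowledge that helper uses this table, with `gather_output=True` for "colwise_rep", `input_is_parallel=False` for "rowwise_rep", and a fall back to `ReplicatedLinear` for unknown strings, but vllm_omni may differ. All classes here are hypothetical stand-ins that only record the arguments a real vLLM constructor would receive.

```python
from dataclasses import dataclass, field


@dataclass
class Linear:
    """Hypothetical stand-in for torch.nn.Linear (bias as a bool here)."""
    in_features: int
    out_features: int
    bias: bool = True


@dataclass
class _ParallelBase:
    """Records the arguments a vLLM linear constructor would receive."""
    input_size: int
    output_size: int
    bias: bool
    quant_config: object = None
    prefix: str = ""
    extra: dict = field(default_factory=dict)


class ColumnParallelLinear(_ParallelBase):
    pass


class RowParallelLinear(_ParallelBase):
    pass


class ReplicatedLinear(_ParallelBase):
    pass


# style -> (class, extra constructor kwargs); assumed to mirror upstream vLLM.
_STYLE_TABLE = {
    "colwise": (ColumnParallelLinear, {}),
    "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
    "rowwise": (RowParallelLinear, {}),
    "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
    "replicate": (ReplicatedLinear, {}),
}


def replace_linear_class_sketch(linear, style="replicate", quant_config=None, *, prefix=""):
    """Pick a parallel linear class for `style` and copy the layer's shape."""
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
    # Unknown style strings fall back to a replicated layer.
    cls, extra = _STYLE_TABLE.get(style, (ReplicatedLinear, {}))
    return cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias,  # real code would test `linear.bias is not None`
        quant_config=quant_config,
        prefix=prefix,
        extra=dict(extra),
    )
```

The "_rep" variants reuse the same sharded class and only change how activations are gathered or split, which is why they map to extra keyword arguments rather than separate classes.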