
vllm_omni.diffusion.layers.mot.mot_row_parallel_linear

MoTRowParallelLinear

Bases: RowParallelLinear

RowParallelLinear with Mixture-of-Tokens routing.

text weights: directly on self (self.weight, self.weight_scale, ...), created by the standard RowParallelLinear.__init__ process.

vae weights: on the permanent submodule self.gen_exp (self.gen_exp.weight, ...), created by quant_method.create_weights(self.gen_exp, ...). gen_exp.quant_method points to the same quant_method, so the vLLM framework's process_weights_after_loading can automatically discover and process it.

Forward behavior
  • und mode (text_indices is None): fully reuses super().forward()
  • gen mode: calls the MoT fused GEMM kernel, then performs the tensor-parallel (TP) all-reduce
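
The two modes above can be illustrated with a small NumPy reference, offered only as a sketch of the assumed semantics and not as the actual fused kernel. It assumes that in gen mode the rows listed in text_indices use the text weights and the rows listed in vae_indices use the gen_exp weights, and it simulates the TP all-reduce as a sum over per-rank partial outputs. The function name mot_row_parallel_reference and its shard-list arguments are hypothetical.

```python
import numpy as np


def mot_row_parallel_reference(x, w_text_shards, w_gen_shards,
                               text_indices=None, vae_indices=None):
    """Reference semantics only (hypothetical helper, not the vllm_omni kernel).

    x:              (num_tokens, in_features)
    w_text_shards:  list of (out_features, in_features // tp) arrays, one per rank
    w_gen_shards:   same layout, standing in for the gen_exp weights
    """
    tp = len(w_text_shards)
    # Row parallelism: each simulated rank sees one slice of the input features.
    x_shards = np.split(x, tp, axis=1)
    partials = []
    for rank in range(tp):
        xs = x_shards[rank]
        if text_indices is None:
            # und mode: every token goes through the text weights.
            partial = xs @ w_text_shards[rank].T
        else:
            # gen mode (assumed routing): pick the weight set per token.
            out_features = w_text_shards[rank].shape[0]
            partial = np.zeros((x.shape[0], out_features), dtype=x.dtype)
            partial[text_indices] = xs[text_indices] @ w_text_shards[rank].T
            partial[vae_indices] = xs[vae_indices] @ w_gen_shards[rank].T
        partials.append(partial)
    # The all-reduce is simulated by summing the per-rank partial outputs.
    return np.sum(partials, axis=0)
```

Because the partial products are summed across ranks, the result should match an unsharded matmul with the corresponding full weight matrix for each token group.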

gen_exp instance-attribute

gen_exp = torch.nn.Module()

forward

forward(
    input_: Tensor,
    text_indices: Tensor | None = None,
    vae_indices: Tensor | None = None,
) -> Tensor | tuple[Tensor, Parameter | None]
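
Since the signature allows either a bare tensor or an (output, bias) tuple, callers may want to normalize the result. The helper below is a minimal sketch; split_output_and_bias is a hypothetical name, and which configuration produces the tuple form is an assumption not stated on this page.

```python
from typing import Any


def split_output_and_bias(result: Any) -> tuple[Any, Any]:
    """Normalize a forward() result to an (output, bias) pair.

    Hypothetical helper: a tuple result is unpacked as (output, bias);
    a bare result is paired with None.
    """
    if isinstance(result, tuple):
        output, bias = result
        return output, bias
    return result, None
```

A call such as `output, bias = split_output_and_bias(layer(x, text_indices, vae_indices))` would then work for both return shapes, where layer stands for a MoTRowParallelLinear instance.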