vllm_omni.diffusion.layers.mot.mot_row_parallel_linear
MoTRowParallelLinear
Bases: RowParallelLinear
RowParallelLinear with Mixture-of-Tokens routing.
text weights: stored directly on self (self.weight, self.weight_scale, ...), created by the standard RowParallelLinear.__init__ process.
vae weights: stored on the permanent submodule self.gen_exp (self.gen_exp.weight, ...), created by quant_method.create_weights(self.gen_exp, ...). gen_exp.quant_method points to the same quant_method, so the vLLM framework's process_weights_after_loading discovers and processes it automatically.
Forward behavior
- und mode (text_indices is None): fully reuse super().forward()
- gen mode: call MoT fused GEMM kernel, then execute TP all-reduce
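The two forward modes can be illustrated with a toy, pure-Python sketch. It is not the actual implementation: tensors are replaced by nested lists, the fused GEMM kernel by a per-token loop, and the TP all-reduce by an in-process sum over simulated ranks. The function names (matvec, rank_partial, all_reduce_sum, forward) and the tp_size parameter are hypothetical, and the sketch assumes that in gen mode every token not listed in text_indices uses the gen_exp weights.

```python
def matvec(row, weight):
    """Multiply a 1 x K row by a K x N weight given as nested lists."""
    return [sum(row[k] * weight[k][n] for k in range(len(row)))
            for n in range(len(weight[0]))]


def rank_partial(x_shard, w_text, w_gen, text_indices):
    """Partial output of one simulated TP rank for its input-dim shard."""
    if text_indices is None:
        # und mode: ordinary row-parallel GEMM with the text weights only.
        return [matvec(row, w_text) for row in x_shard]
    # gen mode (assumption): tokens in text_indices use the text weights,
    # every other token uses the gen weights. A fused kernel would do this
    # in a single launch; here it is a plain per-token loop.
    text = set(text_indices)
    return [matvec(row, w_text if i in text else w_gen)
            for i, row in enumerate(x_shard)]


def all_reduce_sum(partials):
    """Toy all-reduce: elementwise sum of the partial outputs of all ranks."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def forward(x, w_text, w_gen, text_indices=None, tp_size=2):
    """Shard the input dimension over tp_size ranks, then reduce."""
    k = len(x[0]) // tp_size  # input features held by each rank
    partials = []
    for r in range(tp_size):
        cols = slice(r * k, (r + 1) * k)
        x_shard = [row[cols] for row in x]
        partials.append(
            rank_partial(x_shard, w_text[cols], w_gen[cols], text_indices))
    return all_reduce_sum(partials)


x = [[1, 2, 3, 4], [5, 6, 7, 8]]            # 2 tokens, 4 input features
w_text = [[1, 0], [0, 1], [1, 0], [0, 1]]   # 4 x 2 text weights
w_gen = [[2, 0], [0, 2], [2, 0], [0, 2]]    # 4 x 2 gen weights

und_out = forward(x, w_text, w_gen)                    # und mode
gen_out = forward(x, w_text, w_gen, text_indices=[0])  # gen mode
```

In this sketch the reduction runs once, after each rank has produced its partial result, which mirrors the "kernel first, then all-reduce" order described for gen mode.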