
vllm_omni.diffusion.layers.mot.mot_qkv_parallel_linear

logger module-attribute

logger = init_logger(__name__)

MoTQKVParallelLinear

Bases: QKVParallelLinear

QKVParallelLinear with Mixture-of-Tokens routing.

Weight storage
  • und weights: stored directly on self (self.weight, self.weight_scale, ...) and created through the standard QKVParallelLinear.__init__ process.
  • gen weights: stored in the permanent submodule self.gen_exp (self.gen_exp.weight, ...) and created via quant_method.create_weights(self.gen_exp, ...). gen_exp.quant_method points to the same quant_method object, so vLLM's process_weights_after_loading detects and processes these weights automatically.
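The point of giving self.gen_exp the same quant_method is discoverability. vLLM's post-load step walks the module tree and calls process_weights_after_loading on every module that exposes a quant_method. The sketch below models that walk in plain Python. Module, QuantMethod, build_mot_layer and post_load_pass are hypothetical stand-ins, not vLLM or PyTorch APIs, and the real create_weights takes many more arguments than this simplified version.

```python
class QuantMethod:
    """Hypothetical stand-in for a vLLM quant method object."""

    def __init__(self):
        self.processed = []  # names of modules this method has post-processed

    def create_weights(self, layer, name):
        # Simplified: register a placeholder weight on whichever module is passed in.
        layer.params[name] = [0.0]

    def process_weights_after_loading(self, layer):
        self.processed.append(layer.name)


class Module:
    """Minimal hypothetical module tree (stand-in for torch.nn.Module)."""

    def __init__(self, name):
        self.name = name
        self.params = {}
        self.children = {}

    def named_modules(self):
        yield self.name, self
        for child in self.children.values():
            yield from child.named_modules()


def build_mot_layer(share_quant_method=True):
    qm = QuantMethod()
    layer = Module("qkv")
    layer.quant_method = qm
    qm.create_weights(layer, "weight")      # und weights live on the layer itself
    gen_exp = Module("qkv.gen_exp")
    layer.children["gen_exp"] = gen_exp     # permanent, registered submodule
    qm.create_weights(gen_exp, "weight")    # gen weights live on gen_exp
    if share_quant_method:
        gen_exp.quant_method = qm           # same object, so the walk finds it
    return layer, qm


def post_load_pass(model):
    # Loader-style walk: any module exposing a quant_method gets processed.
    for _, module in model.named_modules():
        qm = getattr(module, "quant_method", None)
        if qm is not None:
            qm.process_weights_after_loading(module)


layer, qm = build_mot_layer()
post_load_pass(layer)
print(qm.processed)  # ['qkv', 'qkv.gen_exp']
```

If gen_exp did not carry quant_method (build_mot_layer(share_quant_method=False)), the walk would process only "qkv". The gen weights would then be left in their raw loaded form.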

Forward behavior
  • und mode (text_indices is None): delegates entirely to super().forward().
  • gen mode (text_indices is provided): calls the MoT fused GEMM kernel.
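The two modes above can be read as the unfused reference computation below. This is a sketch built on three assumptions: in gen mode, rows at text_indices are projected with the und weights, rows at vae_indices are projected with the gen_exp weights, and the two index sets together cover every row. mot_qkv_reference and linear are hypothetical helpers. The real kernel does the gather, both GEMMs and the scatter in one fused launch on tensors, not Python lists.

```python
def linear(rows, weight):
    """Compute y = x @ W.T for each row x; weight is out_features x in_features."""
    return [[sum(w * x for w, x in zip(w_row, row)) for w_row in weight] for row in rows]


def mot_qkv_reference(input_, und_weight, gen_weight, text_indices=None, vae_indices=None):
    if text_indices is None:
        # und mode: every token uses the und weights (the plain QKVParallelLinear path).
        return linear(input_, und_weight)
    if vae_indices is None:
        raise ValueError("gen mode expects both text_indices and vae_indices")
    output = [[0.0] * len(und_weight) for _ in input_]
    # gen mode: gather each token group, apply its expert's weights, scatter back in place.
    for indices, weight in ((text_indices, und_weight), (vae_indices, gen_weight)):
        projected = linear([input_[i] for i in indices], weight)
        for i, row in zip(indices, projected):
            output[i] = row
    return output


tokens = [[1, 0], [0, 1], [1, 1]]
und_w = [[1, 0], [0, 1]]  # identity
gen_w = [[2, 0], [0, 2]]  # doubles its inputs
print(mot_qkv_reference(tokens, und_w, gen_w, text_indices=[0, 2], vae_indices=[1]))
# [[1, 0], [0, 2], [1, 1]]
```

Output rows keep their original positions. The text and image tokens can therefore stay interleaved in one sequence for the attention that follows.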

gen_exp instance-attribute

gen_exp = torch.nn.Module()

forward

forward(
    input_: Tensor,
    text_indices: Tensor | None = None,
    vae_indices: Tensor | None = None,
) -> Tensor | tuple[Tensor, Parameter | None]
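The return annotation reflects the convention in vLLM's linear layers. As I understand recent versions, a layer built with return_bias=True returns (output, output_bias), where output_bias is the bias only when skip_bias_add is set and None otherwise. With return_bias=False it returns just the tensor. Callers that must handle both forms can normalize the result as sketched below. unpack_linear_output is a hypothetical helper, and strings stand in for tensors so the sketch runs without torch.

```python
from typing import Any, Optional, Tuple


def unpack_linear_output(result: Any) -> Tuple[Any, Optional[Any]]:
    """Normalize `Tensor | tuple[Tensor, bias | None]` into (output, bias_or_None)."""
    if isinstance(result, tuple):
        output, bias = result
        return output, bias
    # A bare tensor means the layer was built without returning the bias.
    return result, None


qkv, bias = unpack_linear_output(("qkv_out", None))
print(qkv, bias)  # qkv_out None
```

When the second element is not None, the caller still has to add that bias itself.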