vllm.model_executor.layers.fused_moe.modular_kernel
FusedMoEModularKernel
Bases: Module
This class combines a FusedMoEPrepareAndFinalize instance and a FusedMoEPermuteExpertsUnpermute instance to provide an interface that is compatible with the fused_experts function in fused_moe.py. It also manages any required scratch space.
Note: Each instance of this class should be used for a single model layer only, because the component objects may hold layer-specific state.
Source code in vllm/model_executor/layers/fused_moe/modular_kernel.py
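The composition described above can be sketched in plain Python. This is a minimal illustration, not vLLM's implementation: the class names PrepareAndFinalize, PermuteExpertsUnpermute, ModularKernel, NoOpPrepare and ScaleByExpertId are hypothetical, the method signatures are heavily simplified, and nested lists replace tensors. It shows the order in which the composite calls its two components and that the composite, rather than the components, allocates the scratch space.

```python
from abc import ABC, abstractmethod


class PrepareAndFinalize(ABC):
    """Hypothetical, simplified stand-in for FusedMoEPrepareAndFinalize."""

    @abstractmethod
    def prepare(self, a1):
        """Quantize and/or dispatch the input activations."""

    @abstractmethod
    def finalize(self, output, fused_expert_output, topk_weights):
        """Weight and reduce the (M, topk, K) expert outputs into output."""


class PermuteExpertsUnpermute(ABC):
    """Hypothetical, simplified stand-in for FusedMoEPermuteExpertsUnpermute."""

    @abstractmethod
    def workspace_shapes(self, M, topk, K):
        """Return the element counts of (workspace13, workspace2)."""

    @abstractmethod
    def apply(self, a1q, topk_ids, workspace13, workspace2):
        """Return the unweighted, unreduced (M, topk, K) expert outputs."""


class ModularKernel:
    """Hypothetical composite that drives the two components in order."""

    def __init__(self, prepare_finalize, fused_experts):
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts

    def forward(self, hidden_states, topk_weights, topk_ids):
        M, K = len(hidden_states), len(hidden_states[0])
        topk = len(topk_ids[0])
        # 1. Quantize and/or dispatch.
        a1q = self.prepare_finalize.prepare(hidden_states)
        # 2. The composite, not the components, allocates the scratch space.
        n13, n2 = self.fused_experts.workspace_shapes(M, topk, K)
        workspace13, workspace2 = [0.0] * n13, [0.0] * n2
        # 3. Run the experts.
        fused_out = self.fused_experts.apply(a1q, topk_ids, workspace13, workspace2)
        # 4. Combine, weight and reduce into a fresh (M, K) output.
        output = [[0.0] * K for _ in range(M)]
        self.prepare_finalize.finalize(output, fused_out, topk_weights)
        return output


class NoOpPrepare(PrepareAndFinalize):
    """Single device, no quantization: prepare is the identity."""

    def prepare(self, a1):
        return a1

    def finalize(self, output, fused_expert_output, topk_weights):
        for m, per_token in enumerate(fused_expert_output):
            for j, vec in enumerate(per_token):
                for k, v in enumerate(vec):
                    output[m][k] += topk_weights[m][j] * v


class ScaleByExpertId(PermuteExpertsUnpermute):
    """Toy experts: expert e multiplies its input by (e + 1)."""

    def workspace_shapes(self, M, topk, K):
        return M * topk * K, M * topk * K

    def apply(self, a1q, topk_ids, workspace13, workspace2):
        # The toy experts need no scratch space, so the workspaces go unused.
        return [[[(e + 1) * x for x in row] for e in ids]
                for row, ids in zip(a1q, topk_ids)]
```

For example, `ModularKernel(NoOpPrepare(), ScaleByExpertId()).forward([[1.0, 2.0]], [[0.5, 0.5]], [[0, 1]])` returns `[[1.5, 3.0]]`: the token is routed to toy experts 0 and 1, which scale it by 1 and 2, and finalize combines the two results with weight 0.5 each. Keeping the two stages behind separate abstract interfaces is what lets either half be swapped independently.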
__init__
__init__(
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
)
forward
forward(
hidden_states: Tensor,
w1: Tensor,
w2: Tensor,
topk_weights: Tensor,
topk_ids: Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[Tensor] = None,
w1_scale: Optional[Tensor] = None,
w2_scale: Optional[Tensor] = None,
w1_zp: Optional[Tensor] = None,
w2_zp: Optional[Tensor] = None,
a1_scale: Optional[Tensor] = None,
a2_scale: Optional[Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> Tensor
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and a top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in place. Defaults to False.
- activation (str): The activation function to apply after the first MoE gemm. Defaults to "silu".
- global_num_experts (int): The total number of experts in the global expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1 (the input activations).
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When True, the topk weights are applied directly to the inputs. This is only applicable when topk is 1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
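To make the computation concrete, here is a plain-Python reference for the simplest case: unquantized weights, a single device, no expert_map, and nested lists instead of tensors. It assumes w1[e] has shape (N, K) and w2[e] has shape (K, N), and it applies SiLU elementwise to the first gemm output; real kernels may instead use a gated activation, which this sketch does not model. The helpers silu, matvec and reference_moe are hypothetical.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def matvec(w, x):
    """Return w @ x for w given as a list of rows."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def reference_moe(hidden_states, w1, w2, topk_weights, topk_ids,
                  apply_router_weight_on_input=False):
    """Unquantized, single-device MoE reference (hypothetical helper).

    hidden_states: M rows of length K
    w1[e]: (N, K) matrix, w2[e]: (K, N) matrix for each expert e
    topk_weights, topk_ids: M rows of length topk
    """
    if apply_router_weight_on_input:
        assert all(len(ids) == 1 for ids in topk_ids), "requires topk == 1"
    output = []
    for x, weights, ids in zip(hidden_states, topk_weights, topk_ids):
        acc = [0.0] * len(x)
        for wt, e in zip(weights, ids):
            # Either scale the input up front, or scale the expert output at the end.
            inp = [wt * v for v in x] if apply_router_weight_on_input else x
            h = [silu(v) for v in matvec(w1[e], inp)]  # first gemm + activation
            y = matvec(w2[e], h)                        # second gemm
            out_scale = 1.0 if apply_router_weight_on_input else wt
            acc = [a + out_scale * v for a, v in zip(acc, y)]
        output.append(acc)
    return output
```

Because the activation is nonlinear, applying the router weight to the input and applying it to the expert output are not numerically equivalent, which is why the choice is exposed as an explicit flag.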
FusedMoEPermuteExpertsUnpermute
Bases: ABC
An abstract base class for the [Permute-Experts-Unpermute] step described above.
activation
apply
abstractmethod
apply(
hidden_states: Tensor,
w1: Tensor,
w2: Tensor,
topk_ids: Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[Tensor],
w1_scale: Optional[Tensor],
w2_scale: Optional[Tensor],
w1_zp: Optional[Tensor],
w2_zp: Optional[Tensor],
a1q_scale: Optional[Tensor],
a2_scale: Optional[Tensor],
workspace13: Tensor,
workspace2: Tensor,
expert_num_tokens: Optional[Tensor],
) -> Tensor
This function computes the intermediate result of a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2.
- hidden_states (torch.Tensor): The (quantized) input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first MoE gemm.
- global_num_experts (int): The total number of experts in the global expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for w2.
- a1q_scale (Optional[torch.Tensor]): Optional scale for the quantized a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- workspace13 (torch.Tensor): A scratch tensor used for the gemm outputs; it must be large enough to hold the output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation function.
- expert_num_tokens (Optional[torch.Tensor]): An optional tensor containing the number of tokens assigned to each expert when the input uses the batched experts format.
Returns:
- torch.Tensor: The unweighted, unreduced output tensor.
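"Unweighted, unreduced" means apply returns one K-vector per (token, expert slot) pair and leaves weighting and summation to finalize. The plain-Python sketch below illustrates that contract together with expert_map. It assumes that expert_map marks experts owned by another expert-parallel shard with -1 and that their slots are left as zeros; the function name and the expert_fns argument are hypothetical, and quantization and workspaces are omitted.

```python
def apply_experts_unweighted(hidden_states, topk_ids, expert_fns, expert_map=None):
    """Return the unweighted, unreduced (M, topk, K) expert outputs.

    expert_fns: the local experts, indexed by local expert id.
    expert_map: optional list mapping global -> local expert id, where -1
    (an assumption of this sketch) marks an expert on another shard.
    """
    K = len(hidden_states[0])
    out = []
    for x, ids in zip(hidden_states, topk_ids):
        row = []
        for global_id in ids:
            local_id = global_id if expert_map is None else expert_map[global_id]
            if local_id < 0:
                row.append([0.0] * K)  # not computed on this shard
            else:
                row.append(expert_fns[local_id](x))  # no topk weight applied
        out.append(row)
    return out
```

With `expert_map = [-1, 0, -1, 1]`, a token routed to global experts 1 and 3 runs local experts 0 and 1, while a token routed to global expert 0 gets a zero slot on this shard.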
workspace_shapes
abstractmethod
workspace_shapes(
a: Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, dtype]
Compute the number of elements for the temporary outputs of the two gemms and the activation in the fused expert function. Since the output of the first gemm is no longer needed once the activation has consumed it, the workspace for the first gemm can be shared with the workspace for the last gemm.
Returns a tuple of:
- Number of workspace13 elements: must be large enough to hold the result of either expert gemm.
- Number of workspace2 elements: must be large enough to hold the result of the activation function.
- Workspace type: the dtype to use for the workspace tensors.
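The buffer sharing can be illustrated with a toy pipeline over flat Python lists. The sizing rule in example_workspace_shapes is an assumption for a hypothetical implementation whose first gemm writes M * topk rows of N values, whose gated activation halves N, and whose second gemm writes M * topk rows of K values; real implementations may size their buffers differently, and the toy "gemms" are placeholders. What the sketch does show is that the second gemm may overwrite workspace13 because the activation has already consumed the first gemm's output.

```python
def example_workspace_shapes(M, N, K, topk, dtype="bfloat16"):
    """Hypothetical sizing: returns (workspace13, workspace2, dtype)."""
    workspace13 = M * topk * max(N, K)  # big enough for either gemm output
    workspace2 = M * topk * (N // 2)    # gated activation output
    return workspace13, workspace2, dtype


def toy_pipeline(x_rows, N, K):
    """gemm1 -> workspace13, activation -> workspace2, gemm2 -> workspace13."""
    rows = len(x_rows)  # plays the role of M * topk with topk = 1
    n13, n2, _ = example_workspace_shapes(rows, N, K, topk=1)
    ws13, ws2 = [0.0] * n13, [0.0] * n2
    half = N // 2
    # Toy "gemm1": column n of row r is sum(x) * (n + 1).
    for r, x in enumerate(x_rows):
        for n in range(N):
            ws13[r * N + n] = sum(x) * (n + 1)
    # Toy gated "activation": multiply the two halves of each gemm1 row.
    for r in range(rows):
        for n in range(half):
            ws2[r * half + n] = ws13[r * N + n] * ws13[r * N + half + n]
    # gemm1's output is dead now, so toy "gemm2" reuses workspace13.
    for r in range(rows):
        row_sum = sum(ws2[r * half:(r + 1) * half])
        for k in range(K):
            ws13[r * K + k] = row_sum
    return [ws13[r * K:(r + 1) * K] for r in range(rows)]
```

Sizing workspace13 with max(N, K) is what makes the reuse safe in both directions: it holds the first gemm's output when N is larger and the second gemm's output when K is larger.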
FusedMoEPrepareAndFinalize
Bases: ABC
An abstract base class for the [Quantize-Prepare] and [Finalize] steps described above.
finalize
abstractmethod
finalize(
output: Tensor,
fused_expert_output: Tensor,
topk_weights: Tensor,
topk_ids: Tensor,
apply_router_weight_on_input: bool,
) -> None
Perform any required combine step, apply the weights, and reduce the fused experts output.
- output: The output tensor, written in place. Must have shape (M, K).
- fused_expert_output: The unweighted, unreduced output of the fused experts, with shape (M, topk, K).
- topk_weights: The weights to be applied to fused_expert_output.
- topk_ids: The topk ids.
- apply_router_weight_on_input: When False, apply the weights to fused_expert_output (when True, they were already applied to the input in prepare).
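On a single device, with no combine step, finalize reduces to a weighted sum over the topk axis. The minimal plain-Python sketch below works under that assumption (finalize_reference is a hypothetical name, and nested lists stand in for tensors); like the signature above, it writes into output in place and returns None.

```python
def finalize_reference(output, fused_expert_output, topk_weights,
                       apply_router_weight_on_input):
    """Weight and reduce (M, topk, K) expert outputs into output (M, K), in place."""
    for m, per_token in enumerate(fused_expert_output):
        row = output[m]
        for k in range(len(row)):
            row[k] = 0.0  # overwrite rather than accumulate into stale data
        for j, vec in enumerate(per_token):
            # If prepare() already applied the weights, only sum here.
            w = 1.0 if apply_router_weight_on_input else topk_weights[m][j]
            for k, v in enumerate(vec):
                row[k] += w * v
```

For instance, expert outputs `[[1.0, 2.0], [3.0, 4.0]]` with weights `[0.5, 0.25]` reduce to `[1.25, 2.0]`, or to the plain sum `[4.0, 6.0]` when apply_router_weight_on_input is True.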
prepare
abstractmethod
prepare(
a1: Tensor,
a1_scale: Optional[Tensor],
a2_scale: Optional[Tensor],
topk_weights: Tensor,
topk_ids: Tensor,
num_experts: int,
expert_map: Optional[Tensor],
apply_router_weight_on_input: bool,
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]
Perform any quantization and/or dispatching needed for this kernel.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1.
- a2_scale: Optional scales for the second MoE gemm. Required to make sure the quantization is consistent for both gemms.
- topk_weights: The topk weights.
- topk_ids: The topk ids.
- num_experts: The total number of experts in the global expert space.
- expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching.
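Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- An optional tensor containing the number of tokens assigned to each expert (passed on as expert_num_tokens to FusedMoEPermuteExpertsUnpermute.apply).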
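As an illustration of the quantization half of prepare, the sketch below applies the router weight to the input when requested and then performs symmetric per-tensor int8 quantization in plain Python. The int8 scheme, the name prepare_reference and its simplified argument list are assumptions for this example (real kernels may use other formats such as FP8, and other scale granularities); since nothing is dispatched, the third element of the returned tuple is None.

```python
def prepare_reference(a1, topk_weights, apply_router_weight_on_input,
                      a1_scale=None):
    """Hypothetical single-device prepare: optional router weighting, then
    symmetric per-tensor int8 quantization. Returns (a1q, a1q_scale, None)."""
    if apply_router_weight_on_input:
        # One weight per token, so this path assumes topk == 1.
        assert all(len(w) == 1 for w in topk_weights), "requires topk == 1"
        a1 = [[w[0] * v for v in row] for row, w in zip(a1, topk_weights)]
    if a1_scale is None:
        # Dynamic quantization: map the largest magnitude to 127.
        amax = max(abs(v) for row in a1 for v in row)
        a1_scale = amax / 127.0 if amax > 0 else 1.0
    # round() rounds halves to even; values are clamped to the int8 range.
    a1q = [[max(-128, min(127, round(v / a1_scale))) for v in row] for row in a1]
    # Nothing is dispatched here, so there is no third output to report.
    return a1q, a1_scale, None
```

Passing a1_scale corresponds to static quantization with a precomputed scale; leaving it as None derives the scale from the data, as in the dynamic branch above.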
_moe_problem_size
_moe_problem_size(
a1: Tensor, w1: Tensor, w2: Tensor, topk_ids: Tensor
) -> tuple[int, int, int, int, int]
Extract the MoE problem size from the given tensor arguments:
- a1: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation tensors is not obvious. It needs to be done this way specifically due to subtle issues with particular kernels, e.g. the int4 kernels divide the trailing dimension by two, so it's not "correct" to extract N or K from the trailing dimension of w1 or w2. Similarly, some kernels transpose the weights, so this needs to be kept in mind.
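The rule in the note above can be sketched with plain shape tuples. This sketch assumes the non-transposed layouts w1: (E, N, K) and w2: (E, K, N), with the trailing dimension possibly halved by int4 packing; it returns (E, M, N, K, topk) in that order and treats a 3-D a1 as the batched format (E, max_tokens, K). The function name, the tuple ordering and the batched-layout handling are assumptions of this example rather than a transcription of the source.

```python
def moe_problem_size(a1_shape, w1_shape, w2_shape, topk_ids_shape):
    """Sketch: derive (E, M, N, K, topk) from tensor shapes.

    N is read from w1's middle dim and K from w2's middle dim, never from a
    trailing dim, because packed formats (e.g. int4, two values per byte)
    halve the trailing dimension.
    """
    assert len(w1_shape) == 3 and len(w2_shape) == 3
    E, N, _ = w1_shape
    K = w2_shape[1]
    if len(a1_shape) == 2:
        M = a1_shape[0]  # standard (tokens, hidden) layout
        assert topk_ids_shape[0] == M
    else:
        # Assumed batched layout: (E, max_tokens_per_expert, hidden).
        assert len(a1_shape) == 3 and a1_shape[0] == E
        M = a1_shape[1]
    topk = topk_ids_shape[1]
    return E, M, N, K, topk
```

With int4-packed weights w1: (4, 128, 32) and w2: (4, 64, 64), reading the trailing dimensions would give K = 32 and N = 64, whereas the middle dimensions give the intended N = 128 and K = 64.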