Bases: Qwen3MoeSparseMoeBlock
Overrides `forward` to flatten 3D tensor input (B, S, H) to (B*S, H)
and to handle the (shared_out, fused_out) tuple returned by SharedFusedMoE.
Source code in vllm_gaudi/models/qwen3_moe.py
class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):
    """HPU variant of the Qwen3 sparse-MoE block.

    Only ``forward`` is overridden. It folds every leading dimension of the
    input into a single token axis (so a 3-D ``(B, S, H)`` batch becomes
    ``(B*S, H)``) before routing, restores the original leading shape on the
    way out, and accepts either a single tensor (``FusedMoE``) or a
    ``(shared_out, fused_out)`` tuple (``SharedFusedMoE``) from
    ``self.experts``.
    """

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run gate + experts on a token-flattened view of ``hidden_states``.

        Args:
            hidden_states: Activations whose last dimension is the hidden
                size; all leading dimensions are treated as tokens.

        Returns:
            The expert output reshaped to ``(*hidden_states.shape[:-1], H)``.

        Raises:
            RuntimeError: If ``self.experts`` returns a tuple that does not
                have exactly two elements, or whose second element
                (``fused_out``) is ``None``.
        """
        shape = hidden_states.shape
        hidden_size = shape[-1]
        tokens = hidden_states.reshape(-1, hidden_size)  # (N, H), N = prod(shape[:-1])
        total_tokens = tokens.shape[0]

        # Tolerate blocks that don't define `is_sequence_parallel` (treated as disabled).
        seq_parallel = getattr(self, "is_sequence_parallel", False)
        if seq_parallel:
            tokens = sequence_parallel_chunk(tokens)

        router_logits, _ = self.gate(tokens)
        result = self.experts(hidden_states=tokens, router_logits=router_logits)

        if not isinstance(result, tuple):
            # Non-tuple return: plain FusedMoE, kept for backward compatibility.
            combined = result
        elif len(result) != 2:
            raise RuntimeError(f"unexpected experts() tuple length={len(result)}; "
                               "expected (shared_out, fused_out).")
        else:
            # SharedFusedMoE yields (shared_out, fused_out); shared_out may be None.
            shared_part, routed_part = result
            if routed_part is None:
                raise RuntimeError("experts() returned fused_out=None")
            combined = routed_part if shared_part is None else shared_part + routed_part

        if seq_parallel:
            # Gather along the token axis, then keep only the original token count.
            combined = tensor_model_parallel_all_gather(combined, 0)[:total_tokens]

        return combined.reshape(*shape[:-1], hidden_size)
|
forward
Source code in vllm_gaudi/models/qwen3_moe.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Run gate + experts on a token-flattened view of ``hidden_states``.

    Args:
        hidden_states: Activations whose last dimension is the hidden
            size; all leading dimensions are treated as tokens.

    Returns:
        The expert output reshaped to ``(*hidden_states.shape[:-1], H)``.

    Raises:
        RuntimeError: If ``self.experts`` returns a tuple that does not
            have exactly two elements, or whose second element
            (``fused_out``) is ``None``.
    """
    shape = hidden_states.shape
    hidden_size = shape[-1]
    tokens = hidden_states.reshape(-1, hidden_size)  # (N, H), N = prod(shape[:-1])
    total_tokens = tokens.shape[0]

    # Tolerate blocks that don't define `is_sequence_parallel` (treated as disabled).
    seq_parallel = getattr(self, "is_sequence_parallel", False)
    if seq_parallel:
        tokens = sequence_parallel_chunk(tokens)

    router_logits, _ = self.gate(tokens)
    result = self.experts(hidden_states=tokens, router_logits=router_logits)

    if not isinstance(result, tuple):
        # Non-tuple return: plain FusedMoE, kept for backward compatibility.
        combined = result
    elif len(result) != 2:
        raise RuntimeError(f"unexpected experts() tuple length={len(result)}; "
                           "expected (shared_out, fused_out).")
    else:
        # SharedFusedMoE yields (shared_out, fused_out); shared_out may be None.
        shared_part, routed_part = result
        if routed_part is None:
            raise RuntimeError("experts() returned fused_out=None")
        combined = routed_part if shared_part is None else shared_part + routed_part

    if seq_parallel:
        # Gather along the token axis, then keep only the original token count.
        combined = tensor_model_parallel_all_gather(combined, 0)[:total_tokens]

    return combined.reshape(*shape[:-1], hidden_size)
|