Bases: Qwen3MoeSparseMoeBlock
Source code in vllm_gaudi/models/qwen3_moe.py
class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):
    """Sparse MoE block whose ``forward`` flattens its input to 2-D before routing."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the gate and the experts, returning a tensor shaped like the input.

        Args:
            hidden_states: Tensor whose last dimension is the hidden size; all
                leading dimensions are collapsed into a single row axis.

        Returns:
            The experts' output reshaped back to ``hidden_states.shape``.
        """
        input_shape = hidden_states.shape
        hidden_dim = input_shape[-1]
        flat = hidden_states.reshape(-1, hidden_dim)  # (T, H)
        row_count = flat.size(0)

        # The attribute may be missing on the instance; a missing attribute
        # is treated the same as "sequence parallelism disabled".
        use_sp = getattr(self, "is_sequence_parallel", False)
        if use_sp:
            flat = sequence_parallel_chunk(flat)

        router_logits, _ = self.gate(flat)
        routed = self.experts(hidden_states=flat, router_logits=router_logits)

        if use_sp:
            gathered = tensor_model_parallel_all_gather(routed, 0)
            # Keep only the first T rows. NOTE(review): presumably the chunk
            # step pads the row axis -- confirm against sequence_parallel_chunk.
            routed = gathered[:row_count]

        return routed.reshape(input_shape)
|
forward
Source code in vllm_gaudi/models/qwen3_moe.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Apply the router gate and experts to ``hidden_states``.

    The input is flattened to two dimensions (rows x hidden size) for the
    gate/expert calls and the result is restored to the input's shape.

    Args:
        hidden_states: Tensor whose trailing dimension is the hidden size.

    Returns:
        Tensor with the same shape as ``hidden_states``.
    """
    original_shape = hidden_states.shape
    width = original_shape[-1]
    tokens = hidden_states.reshape(-1, width)  # (T, H)
    token_count = tokens.shape[0]

    # Instances without the attribute behave as if the flag were False.
    sequence_parallel = getattr(self, "is_sequence_parallel", False)

    if sequence_parallel:
        tokens = sequence_parallel_chunk(tokens)

    gate_output = self.gate(tokens)
    router_logits = gate_output[0]  # second element of the pair is unused
    expert_out = self.experts(hidden_states=tokens, router_logits=router_logits)

    if sequence_parallel:
        # Gather along dim 0, then trim back to the pre-chunk row count.
        # NOTE(review): trimming presumably removes padding rows -- confirm.
        expert_out = tensor_model_parallel_all_gather(expert_out, 0)[:token_count]

    return expert_out.reshape(*original_shape[:-1], width)
|