class HpuEagleProposer(EagleProposer):
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens] or [3, num_tokens] when M-RoPE is enabled
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
last_token_indices: torch.Tensor | None,
common_attn_metadata: HPUUnifiedAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
):
num_tokens = target_token_ids.shape[0]
# For decode, the virtual batch_size is batch size * num_tokens
# and the seq_len is always 1
if self.method == "eagle3":
assert isinstance(self.model.model, Eagle3LlamaForCausalLM)
target_hidden_states = \
self.model.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
#print(f"{self.input_ids[:num_tokens]}, {target_token_ids}")
ret_hidden_states = self.model(
input_ids=self.input_ids[:num_tokens].unsqueeze(-1),
positions=target_positions.unsqueeze(-1),
hidden_states=target_hidden_states,
inputs_embeds=None,
attn_metadata=common_attn_metadata,
)
# All MTP related method names are now unified to "mtp"
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
last_hidden_states = last_hidden_states.view(-1, last_hidden_states.shape[-1])
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
draft_token_ids = logits.argmax(dim=-1)
return draft_token_ids.view(-1, 1)
def prepare_inputs(
self,
sampled_token_ids: list[list[int]],
unified_data: UnifiedBatch,
):
"""
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator. And last_token_indices
to indicate indices for the next_token_ids copied to.
"""
# E.g.
# query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3]
# This function computes the intermediate values:
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# And returns:
# query_start_loc{_cpu}:
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# seq_lens{_cpu}:
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# token_indices: [0, 1, ..., q1 - n1 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
spec_decode_metadata = unified_data.spec_decode_metadata
num_draft_tokens = spec_decode_metadata.num_draft_tokens
query_start_loc_cpu = unified_data.query_start_loc_cpu
#seq_lens_cpu = unified_data.seq_lens_cpu
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)
device = self.device
#new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
# [q1 - n1, q2 - n2, q3 - n3] ->
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
new_query_start_loc_cpu = torch.zeros(
query_start_loc_cpu.shape,
dtype=torch.int32,
)
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
total_num_tokens = new_query_start_loc_np[-1]
# Example assuming num_tokens_per_req_np = [2, 4, 3]
# this implies that `new_query_start_locs` is:
# [0, 2, 6, 9] ->
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
# _r1_ ____r2____ ___r3__
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], new_num_tokens_per_req_np)
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
# _r1_ ____r2____ ___r3__
token_offests = (self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded)
# Expand starting positions to match token pattern
# [0, q1, q1 + q2] ->
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
# _r1_ _____r2_______ ___________r3____________
old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
token_indices_np = token_offests + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)
#last_token_indices = (new_query_start_loc_cpu[1:] - 1).to(device, non_blocking=True)
last_token_indices_with_draft = query_start_loc_cpu[1:] - 1
last_token_indices_remove_rejected = last_token_indices_with_draft - num_rejected_tokens
last_token_indices = (last_token_indices_remove_rejected).to(device, non_blocking=True)
return token_indices, last_token_indices