Skip to content

vllm_gaudi.v1.spec_decode.hpu_eagle_unified

logger module-attribute

logger = logger()

HpuEagleProposer

Bases: EagleProposer

Source code in vllm_gaudi/v1/spec_decode/hpu_eagle_unified.py
class HpuEagleProposer(EagleProposer):

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: HPUUnifiedAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    ):
        num_tokens = target_token_ids.shape[0]
        # For decode, the virtual batch_size is batch size * num_tokens
        # and the seq_len is always 1

        if self.method == "eagle3":
            assert isinstance(self.model.model, Eagle3LlamaForCausalLM)
            target_hidden_states = \
                self.model.model.combine_hidden_states(
                    target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        #print(f"{self.input_ids[:num_tokens]}, {target_token_ids}")

        ret_hidden_states = self.model(
            input_ids=self.input_ids[:num_tokens].unsqueeze(-1),
            positions=target_positions.unsqueeze(-1),
            hidden_states=target_hidden_states,
            inputs_embeds=None,
            attn_metadata=common_attn_metadata,
        )

        # All MTP related method names are now unified to "mtp"
        if self.method == "mtp":
            last_hidden_states = ret_hidden_states
            hidden_states = last_hidden_states
        else:
            last_hidden_states, hidden_states = ret_hidden_states
        last_hidden_states = last_hidden_states.view(-1, last_hidden_states.shape[-1])
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

    def prepare_inputs(
        self,
        sampled_token_ids: list[list[int]],
        unified_data: UnifiedBatch,
    ):
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates to the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator. And last_token_indices
        to indicate indices for the next_token_ids copied to.
        """
        # E.g.
        #  query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  seq_lens{_cpu}:
        #       [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        spec_decode_metadata = unified_data.spec_decode_metadata
        num_draft_tokens = spec_decode_metadata.num_draft_tokens
        query_start_loc_cpu = unified_data.query_start_loc_cpu
        #seq_lens_cpu = unified_data.seq_lens_cpu
        num_rejected_tokens = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)

        device = self.device
        #new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
        )
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = new_query_start_loc_np[-1]
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], new_num_tokens_per_req_np)
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offests = (self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded)

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offests + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        #last_token_indices = (new_query_start_loc_cpu[1:] - 1).to(device, non_blocking=True)
        last_token_indices_with_draft = query_start_loc_cpu[1:] - 1
        last_token_indices_remove_rejected = last_token_indices_with_draft - num_rejected_tokens
        last_token_indices = (last_token_indices_remove_rejected).to(device, non_blocking=True)

        return token_indices, last_token_indices

prepare_inputs

prepare_inputs(
    sampled_token_ids: list[list[int]],
    unified_data: UnifiedBatch,
)

This function is used to prepare the inputs for speculative decoding. It updates to the common_attn_metadata to account for the rejected tokens (and newly sampled tokens). It also returns the token indices of the tokens that should be fed to the speculator. And last_token_indices to indicate indices for the next_token_ids copied to.

Source code in vllm_gaudi/v1/spec_decode/hpu_eagle_unified.py
def prepare_inputs(
    self,
    sampled_token_ids: list[list[int]],
    unified_data: UnifiedBatch,
):
    """
    This function is used to prepare the inputs for speculative decoding.
    It updates to the common_attn_metadata to account for the rejected
    tokens (and newly sampled tokens). It also returns the token indices
    of the tokens that should be fed to the speculator. And last_token_indices
    to indicate indices for the next_token_ids copied to.
    """
    # E.g.
    #  query_start_loc{_cpu}:
    #       [0, q1, q1 + q2, q1 + q2 + q3]
    #  seq_lens{_cpu}: [s1, s2, s3]
    #  num_rejected_tokens: [n1, n2, n3]
    # This function computes the intermediate values:
    #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
    # And returns:
    #  query_start_loc{_cpu}:
    #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    #  seq_lens{_cpu}:
    #       [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
    #  token_indices: [0, 1, ..., q1 - n1 - 1,
    #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
    #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

    spec_decode_metadata = unified_data.spec_decode_metadata
    num_draft_tokens = spec_decode_metadata.num_draft_tokens
    query_start_loc_cpu = unified_data.query_start_loc_cpu
    #seq_lens_cpu = unified_data.seq_lens_cpu
    num_rejected_tokens = [
        n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens)
    ]
    num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)

    device = self.device
    #new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens

    # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
    new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
    new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
    new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

    # [q1 - n1, q2 - n2, q3 - n3] ->
    # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    new_query_start_loc_cpu = torch.zeros(
        query_start_loc_cpu.shape,
        dtype=torch.int32,
    )
    new_query_start_loc_np = new_query_start_loc_cpu.numpy()
    np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

    total_num_tokens = new_query_start_loc_np[-1]
    # Example assuming num_tokens_per_req_np = [2, 4, 3]
    # this implies that `new_query_start_locs` is:
    # [0, 2, 6, 9] ->
    # [0, 0, 2, 2, 2, 2, 6, 6, 6]
    #  _r1_  ____r2____  ___r3__
    new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], new_num_tokens_per_req_np)
    # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
    # [0, 1, 0, 1, 2, 3, 0, 1, 2]
    #  _r1_  ____r2____  ___r3__
    token_offests = (self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded)

    # Expand starting positions to match token pattern
    # [0, q1, q1 + q2] ->
    # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
    #  _r1_  _____r2_______  ___________r3____________
    old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
    # Final token indices are:
    # [0, 1,                                // req 1
    #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
    #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
    token_indices_np = token_offests + old_query_start_locs_expanded
    token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

    #last_token_indices = (new_query_start_loc_cpu[1:] - 1).to(device, non_blocking=True)
    last_token_indices_with_draft = query_start_loc_cpu[1:] - 1
    last_token_indices_remove_rejected = last_token_indices_with_draft - num_rejected_tokens
    last_token_indices = (last_token_indices_remove_rejected).to(device, non_blocking=True)

    return token_indices, last_token_indices

propose

propose(
    target_token_ids: Tensor,
    target_positions: Tensor,
    target_hidden_states: Tensor,
    next_token_ids: Tensor,
    last_token_indices: Tensor | None,
    common_attn_metadata: HPUUnifiedAttentionMetadata,
    sampling_metadata: SamplingMetadata,
    mm_embed_inputs: tuple[list[Tensor], Tensor]
    | None = None,
)
Source code in vllm_gaudi/v1/spec_decode/hpu_eagle_unified.py
def propose(
    self,
    # [num_tokens]
    target_token_ids: torch.Tensor,
    # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
    target_positions: torch.Tensor,
    # [num_tokens, hidden_size]
    target_hidden_states: torch.Tensor,
    # [batch_size]
    next_token_ids: torch.Tensor,
    last_token_indices: torch.Tensor | None,
    common_attn_metadata: HPUUnifiedAttentionMetadata,
    sampling_metadata: SamplingMetadata,
    mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
):
    num_tokens = target_token_ids.shape[0]
    # For decode, the virtual batch_size is batch size * num_tokens
    # and the seq_len is always 1

    if self.method == "eagle3":
        assert isinstance(self.model.model, Eagle3LlamaForCausalLM)
        target_hidden_states = \
            self.model.model.combine_hidden_states(
                target_hidden_states)
        assert target_hidden_states.shape[-1] == self.hidden_size

    self.input_ids[:num_tokens - 1] = target_token_ids[1:]
    # Replace the last token with the next token.
    # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
    self.input_ids[last_token_indices] = next_token_ids

    #print(f"{self.input_ids[:num_tokens]}, {target_token_ids}")

    ret_hidden_states = self.model(
        input_ids=self.input_ids[:num_tokens].unsqueeze(-1),
        positions=target_positions.unsqueeze(-1),
        hidden_states=target_hidden_states,
        inputs_embeds=None,
        attn_metadata=common_attn_metadata,
    )

    # All MTP related method names are now unified to "mtp"
    if self.method == "mtp":
        last_hidden_states = ret_hidden_states
        hidden_states = last_hidden_states
    else:
        last_hidden_states, hidden_states = ret_hidden_states
    last_hidden_states = last_hidden_states.view(-1, last_hidden_states.shape[-1])
    sample_hidden_states = last_hidden_states[last_token_indices]
    logits = self.model.compute_logits(sample_hidden_states)

    # Early exit if there is only one draft token to be generated.
    if self.num_speculative_tokens == 1:
        draft_token_ids = logits.argmax(dim=-1)
        return draft_token_ids.view(-1, 1)