Skip to content

vllm_gaudi.v1.spec_decode.hpu_eagle

HpuEagleProposer

Bases: EagleProposer

Source code in vllm_gaudi/v1/spec_decode/hpu_eagle.py
class HpuEagleProposer(EagleProposer):
    """Draft-token proposer built on ``EagleProposer``.

    Defines ``propose`` (greedy multi-step drafting), ``prepare_inputs``
    (index construction for the padded per-request token layout) and
    ``prepare_attn_metadata`` (host-side construction of decode attention
    metadata that is then copied to ``self.device``).
    """

    def propose(
        self,
        # [virtual_batch_size, seq_len]
        target_token_ids,
        # [virtual_batch_size, seq_len]
        target_positions,
        # [virtual_batch_size, seq_len, hidden_size]
        target_hidden_states,
        # [batch_size]
        last_token_indices,
        common_attn_metadata,
        # [num_seq, total_blocks]
        block_table_cpu_tensor,
        model_runner,
    ):
        """Greedily draft ``self.num_speculative_tokens`` tokens per request.

        The draft model is run once over the target tokens and the logits at
        ``last_token_indices`` give the first draft token. Every further token
        comes from a single-token step that feeds back the previous draft
        token, the position advanced by one and the carried hidden state, with
        attention metadata rebuilt by ``prepare_attn_metadata``.

        Args:
            target_token_ids: Token ids fed to the first draft-model call.
            target_positions: Positions matching ``target_token_ids``.
            target_hidden_states: Hidden states fed to the first call; for
                method ``"eagle3"`` they are first passed through
                ``combine_hidden_states``.
            last_token_indices: Indices into the flattened token axis, one per
                request, selecting the token to continue from.
            common_attn_metadata: Attention metadata for the first call.
            block_table_cpu_tensor: Block table forwarded to
                ``prepare_attn_metadata`` for the follow-up steps.
            model_runner: Forwarded to ``prepare_attn_metadata``.

        Returns:
            Tensor of draft token ids with one row per entry of
            ``last_token_indices`` and ``self.num_speculative_tokens`` columns.
        """
        # For decode, the virtual batch_size is batch size * num_tokens
        # and the seq_len is always 1
        batch_size = last_token_indices.shape[0]

        if self.method == "eagle3":
            assert isinstance(self.model.model, Eagle3LlamaForCausalLM)
            target_hidden_states = \
                self.model.model.combine_hidden_states(
                    target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        ret_hidden_states = self.model(
            input_ids=target_token_ids,
            positions=target_positions,
            hidden_states=target_hidden_states,
            inputs_embeds=None,
            attn_metadata=common_attn_metadata,
        )

        # All MTP related method names are now unified to "mtp".
        # "mtp" models return a single tensor; the others return a
        # (last_hidden_states, hidden_states) pair.
        if self.method == "mtp":
            last_hidden_states = ret_hidden_states
            hidden_states = last_hidden_states
        else:
            last_hidden_states, hidden_states = ret_hidden_states
        last_hidden_states = last_hidden_states.view(-1, last_hidden_states.shape[-1])
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        # [num_tokens]
        target_positions = target_positions.view(-1)
        # [batch_size]; tensor indexing yields a copy, so the in-place
        # increments below do not touch the caller's tensor.
        positions = target_positions[last_token_indices]
        if self.method == "mtp":
            hidden_states = target_hidden_states.view(-1, target_hidden_states.shape[-1])
        else:
            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # [batch_size, hidden_size]
        hidden_states = hidden_states[last_token_indices]

        # The first draft tokens
        draft_token_ids = logits.argmax(dim=-1)
        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # Positions used by prepare_attn_metadata need to be on the CPU because
        # compile-only mode for warmup will not do any real computations.
        target_positions_cpu = target_positions.cpu()
        positions_cpu = target_positions_cpu[last_token_indices.cpu()]

        # Decode 1 token each time
        for _ in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            # [batch_size]
            input_ids = draft_token_ids_list[-1].int()

            positions += 1
            # Positions at or past max_model_len are replaced by 0 for the model input.
            exceeds_max_model_len = positions >= self.max_model_len
            clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

            # Prepare the attn metadata (the host copy is advanced in lockstep).
            positions_cpu += 1
            attn_metadata = self.prepare_attn_metadata(block_table_cpu_tensor, positions_cpu, model_runner)

            # [batch_size, 1]
            input_ids = input_ids.view(-1, 1)
            # [batch_size, 1]
            input_positions = clamped_positions.view(-1, 1)
            # [batch_size, 1, hidden_size]
            input_hidden_states = hidden_states.view(-1, 1, hidden_states.shape[-1])
            inputs_embeds = None

            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=input_positions,
                hidden_states=input_hidden_states,
                inputs_embeds=inputs_embeds,
                attn_metadata=attn_metadata,
            )
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states

            # The shape of the returned hidden_states and last_hidden_states:
            # [batch_size, 1, hidden_size]
            # viewed to: [batch_size, hidden_size]
            last_hidden_states = last_hidden_states.view(-1, last_hidden_states.shape[-1])
            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

            # Keep only the first batch_size rows of the model output.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    def prepare_inputs(
        self,
        common_attn_metadata,
        spec_decode_metadata: SpecDecodeMetadata,
        sampled_token_ids: list[list[int]],
    ):
        """Build gather indices into the padded per-request token layout.

        Every request owns ``max(num_draft_tokens) + 1`` consecutive slots.
        For a request with ``n`` draft tokens the rejected count ``r`` is
        ``n + 1 - len(sampled_token_ids[i])`` (``0`` when ``n == 0``); its
        first ``n + 1 - r`` slots are listed by absolute index, the rest of
        its slots are marked ``-1``, and slot ``n - r`` is recorded as the
        request's last token.

        Args:
            common_attn_metadata: Returned unchanged.
            spec_decode_metadata: Supplies ``num_draft_tokens`` per request.
            sampled_token_ids: Sampled token ids per request.

        Returns:
            ``(common_attn_metadata, hidden_states_indices,
            last_token_indices)``; both index tensors are int64 and live on
            ``self.device``.
        """
        assert spec_decode_metadata is not None
        num_draft_tokens = spec_decode_metadata.num_draft_tokens
        # The per-request stride is loop-invariant. ``default=0`` keeps an
        # empty batch from raising inside ``max``.
        step = max(num_draft_tokens, default=0) + 1

        num_rejected_tokens = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens)
        ]

        num_picked_token_indices = []
        last_token_indices = []
        starting_index = 0
        for n, r in zip(num_draft_tokens, num_rejected_tokens):
            last_offset = n - r
            # A last token exists only if its offset falls inside this request's slots.
            if 0 <= last_offset < step:
                last_token_indices.append(starting_index + last_offset)
            # Number of kept slots, bounded to the request's slot range.
            kept = min(max(last_offset + 1, 0), step)
            num_picked_token_indices.extend(range(starting_index, starting_index + kept))
            num_picked_token_indices.extend([-1] * (step - kept))
            starting_index += step
        # Pin the dtype: torch.tensor([]) would otherwise come out as float32.
        hidden_states_indices = torch.tensor(num_picked_token_indices, dtype=torch.int64, device=self.device)
        last_token_indices = torch.tensor(last_token_indices, dtype=torch.int64, device=self.device)
        return common_attn_metadata, hidden_states_indices, last_token_indices

    def prepare_attn_metadata(
            self,
            # [num_seq, total_blocks]
            block_table_cpu_tensor,
            # CPU tensor: [batch_size]
            positions,
            model_runner):
        """Build decode attention metadata for one single-token draft step.

        Block tables and the slot mapping are computed on the CPU, passed to
        ``model_runner.get_habana_paged_attn_buffers`` and the results copied
        to ``self.device`` with ``async_h2d_copy``.

        Args:
            block_table_cpu_tensor: Block table with one row per real
                sequence (no padding rows).
            positions: CPU tensor with one position per (padded) batch entry.
            model_runner: Supplies ``defragmenter``, ``_PAD_BLOCK_ID``,
                ``_PAD_SLOT_ID`` and ``get_habana_paged_attn_buffers``.

        Returns:
            The object produced by ``HPUAttentionMetadataV1.make_decode_metadata``.
        """
        # Prepare attn metadata on CPU. (Improve for pure HPU based attn metadata preparation)
        batch_size = positions.shape[0]
        # Positions at or past max_model_len are replaced by 0 for the slot computation.
        exceeds_max_model_len = positions >= self.max_model_len
        clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

        # Note: block_table_cpu_tensor doesn't include the padding, so it
        # might have fewer rows than the (padded) batch_size.
        num_seq = block_table_cpu_tensor.shape[0]

        # Prepare block tables list
        # block_tables_list is a nested list of shape [num_seq, num_blocks]
        # num_blocks should include the slots needed for the current token
        # positions are the context lengths, and we need +1 for num_blocks
        # NOTE(review): this uses the unclamped positions; if a position can
        # reach max_model_len the requested count may exceed the table width
        # and trip the assert below - confirm callers never get there.
        num_blocks = torch.ceil((positions + 1) / self.block_size).int()
        num_blocks = num_blocks[:num_seq].tolist()
        block_tables_list = []
        for i, n in enumerate(num_blocks):
            seq_block_table = block_table_cpu_tensor[i, :n].tolist()
            assert len(seq_block_table) == n
            block_tables_list.append(seq_block_table)
        # Needs to be resolved by defragmenter
        block_tables_list = model_runner.defragmenter.resolve_all(block_tables_list)

        # Compute slot mapping in [batch_size, 1] shape
        clamped_positions = clamped_positions.view(-1, 1)
        block_numbers = clamped_positions // self.block_size

        # Limit with num_seq because block_table_cpu_tensor is in the shape [num_seq, x]
        block_numbers = block_numbers.to(torch.int64)[:num_seq]
        # Start from the pad block id everywhere, then overwrite the real rows.
        block_ids = torch.ones((batch_size, 1), dtype=torch.int32) * model_runner._PAD_BLOCK_ID
        block_ids[:num_seq] = block_table_cpu_tensor.gather(dim=1, index=block_numbers)
        # Needs to be resolved by defragmenter
        block_ids.apply_(model_runner.defragmenter.resolve)

        # Calculate the slot mapping and fill with padding
        slot_mapping = block_ids * self.block_size + clamped_positions % self.block_size
        # Padding rows cycle through the block_size slots starting at _PAD_SLOT_ID.
        dummy_slots = itertools.cycle(range(model_runner._PAD_SLOT_ID, model_runner._PAD_SLOT_ID + self.block_size))
        slot_mapping[num_seq:].apply_(lambda _, ds=dummy_slots: next(ds))
        # Slot mapping needs to be int64 (long) type
        slot_mapping = slot_mapping.to(torch.int64)

        block_list, block_groups, block_usage = \
            model_runner.get_habana_paged_attn_buffers(
                block_tables_list,
                slot_mapping.tolist(),
                batch_size
            )

        block_list_device = async_h2d_copy(block_list, device=self.device)
        block_usage_device = async_h2d_copy(block_usage, device=self.device)
        block_groups_device = async_h2d_copy(block_groups, device=self.device)
        slot_mapping_device = async_h2d_copy(slot_mapping, device=self.device)

        common_attn_metadata = HPUAttentionMetadataV1.make_decode_metadata(
            block_list=block_list_device,
            block_usage=block_usage_device,
            block_groups=block_groups_device,
            input_positions=None,
            slot_mapping=slot_mapping_device,
            block_size=self.block_size,
            window_block_list=None,
            window_block_usage=None,
            window_block_groups=None,
            chunked_block_list=None,
            chunked_block_usage=None,
            chunked_block_groups=None,
        )

        return common_attn_metadata

prepare_attn_metadata

prepare_attn_metadata(
    block_table_cpu_tensor, positions, model_runner
)
Source code in vllm_gaudi/v1/spec_decode/hpu_eagle.py
def prepare_attn_metadata(
        self,
        # [num_seq, total_blocks]
        block_table_cpu_tensor,
        # CPU tensor: [batch_size]
        positions,
        model_runner):
    """Assemble decode attention metadata for a single-token draft step.

    Everything is computed on the host first (block tables, slot mapping),
    handed to ``model_runner.get_habana_paged_attn_buffers`` and then copied
    to ``self.device`` via ``async_h2d_copy``.

    Args:
        block_table_cpu_tensor: Block table, one row per real sequence.
        positions: CPU tensor holding one position per (padded) batch entry.
        model_runner: Provides ``defragmenter``, ``_PAD_BLOCK_ID``,
            ``_PAD_SLOT_ID`` and ``get_habana_paged_attn_buffers``.

    Returns:
        The result of ``HPUAttentionMetadataV1.make_decode_metadata``.
    """
    block_size = self.block_size
    padded_batch = positions.shape[0]
    # The block table has no padding rows, so it may be shorter than the padded batch.
    real_seqs = block_table_cpu_tensor.shape[0]

    # Positions at or beyond max_model_len are redirected to 0; shaped [batch_size, 1].
    safe_positions = torch.where(positions >= self.max_model_len, 0, positions).view(-1, 1)

    # Per real sequence: number of blocks covering positions 0..position
    # (positions are context lengths, hence the +1 for the current token).
    blocks_needed = torch.ceil((positions + 1) / block_size).int()[:real_seqs].tolist()
    per_seq_blocks = []
    for row, count in enumerate(blocks_needed):
        row_blocks = block_table_cpu_tensor[row, :count].tolist()
        assert len(row_blocks) == count
        per_seq_blocks.append(row_blocks)
    # Block ids have to go through the defragmenter.
    per_seq_blocks = model_runner.defragmenter.resolve_all(per_seq_blocks)

    # Column of the block table that holds each sequence's current block;
    # cut to real_seqs because the table only has that many rows.
    block_index = (safe_positions // block_size).to(torch.int64)[:real_seqs]
    # Pad block id everywhere, real rows overwritten from the table.
    block_ids = torch.ones((padded_batch, 1), dtype=torch.int32) * model_runner._PAD_BLOCK_ID
    block_ids[:real_seqs] = block_table_cpu_tensor.gather(dim=1, index=block_index)
    block_ids.apply_(model_runner.defragmenter.resolve)

    # slot = block id * block_size + offset inside the block.
    slot_mapping = block_ids * block_size + safe_positions % block_size
    # Padding rows take slots cycling through one block starting at _PAD_SLOT_ID.
    first_pad_slot = model_runner._PAD_SLOT_ID
    pad_slots = itertools.cycle(range(first_pad_slot, first_pad_slot + block_size))
    slot_mapping[real_seqs:].apply_(lambda _ignored: next(pad_slots))
    # The slot mapping must be int64 (long).
    slot_mapping = slot_mapping.to(torch.int64)

    block_list, block_groups, block_usage = model_runner.get_habana_paged_attn_buffers(
        per_seq_blocks,
        slot_mapping.tolist(),
        padded_batch,
    )

    device = self.device
    block_list_dev = async_h2d_copy(block_list, device=device)
    block_usage_dev = async_h2d_copy(block_usage, device=device)
    block_groups_dev = async_h2d_copy(block_groups, device=device)
    slot_mapping_dev = async_h2d_copy(slot_mapping, device=device)

    # Window and chunked buffers are not used for this step.
    unused_buffers = dict.fromkeys((
        "window_block_list",
        "window_block_usage",
        "window_block_groups",
        "chunked_block_list",
        "chunked_block_usage",
        "chunked_block_groups",
    ))
    return HPUAttentionMetadataV1.make_decode_metadata(
        block_list=block_list_dev,
        block_usage=block_usage_dev,
        block_groups=block_groups_dev,
        input_positions=None,
        slot_mapping=slot_mapping_dev,
        block_size=block_size,
        **unused_buffers,
    )

prepare_inputs

prepare_inputs(
    common_attn_metadata,
    spec_decode_metadata: SpecDecodeMetadata,
    sampled_token_ids: list[list[int]],
)
Source code in vllm_gaudi/v1/spec_decode/hpu_eagle.py
def prepare_inputs(
    self,
    common_attn_metadata,
    spec_decode_metadata: SpecDecodeMetadata,
    sampled_token_ids: list[list[int]],
):
    """Build gather indices into the padded per-request token layout.

    Every request owns ``max(num_draft_tokens) + 1`` consecutive slots. For a
    request with ``n`` draft tokens the rejected count ``r`` is
    ``n + 1 - len(sampled_token_ids[i])`` (``0`` when ``n == 0``); its first
    ``n + 1 - r`` slots are listed by absolute index, the rest of its slots
    are marked ``-1``, and slot ``n - r`` is recorded as the request's last
    token.

    Args:
        common_attn_metadata: Returned unchanged.
        spec_decode_metadata: Supplies ``num_draft_tokens`` per request.
        sampled_token_ids: Sampled token ids per request.

    Returns:
        ``(common_attn_metadata, hidden_states_indices, last_token_indices)``;
        both index tensors are int64 and live on ``self.device``.
    """
    assert spec_decode_metadata is not None
    num_draft_tokens = spec_decode_metadata.num_draft_tokens
    # The per-request stride is loop-invariant. ``default=0`` keeps an empty
    # batch from raising inside ``max``.
    step = max(num_draft_tokens, default=0) + 1

    num_rejected_tokens = [
        n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens)
    ]

    num_picked_token_indices = []
    last_token_indices = []
    starting_index = 0
    for n, r in zip(num_draft_tokens, num_rejected_tokens):
        last_offset = n - r
        # A last token exists only if its offset falls inside this request's slots.
        if 0 <= last_offset < step:
            last_token_indices.append(starting_index + last_offset)
        # Number of kept slots, bounded to the request's slot range.
        kept = min(max(last_offset + 1, 0), step)
        num_picked_token_indices.extend(range(starting_index, starting_index + kept))
        num_picked_token_indices.extend([-1] * (step - kept))
        starting_index += step
    # Pin the dtype: torch.tensor([]) would otherwise come out as float32.
    hidden_states_indices = torch.tensor(num_picked_token_indices, dtype=torch.int64, device=self.device)
    last_token_indices = torch.tensor(last_token_indices, dtype=torch.int64, device=self.device)
    return common_attn_metadata, hidden_states_indices, last_token_indices

propose

propose(
    target_token_ids,
    target_positions,
    target_hidden_states,
    last_token_indices,
    common_attn_metadata,
    block_table_cpu_tensor,
    model_runner,
)
Source code in vllm_gaudi/v1/spec_decode/hpu_eagle.py
def propose(
    self,
    # [virtual_batch_size, seq_len]
    target_token_ids,
    # [virtual_batch_size, seq_len]
    target_positions,
    # [virtual_batch_size, seq_len, hidden_size]
    target_hidden_states,
    # [batch_size]
    last_token_indices,
    common_attn_metadata,
    # [num_seq, total_blocks]
    block_table_cpu_tensor,
    model_runner,
):
    """Greedily draft ``self.num_speculative_tokens`` tokens per request.

    One draft-model call over the target tokens yields the first draft token
    (argmax of the logits at ``last_token_indices``). Each additional token
    is produced by a single-token step fed with the previous draft token, the
    position advanced by one and the carried hidden state; its attention
    metadata comes from ``prepare_attn_metadata``.

    Returns:
        Tensor of draft token ids, one row per entry of
        ``last_token_indices`` and ``self.num_speculative_tokens`` columns.
    """
    uses_mtp = self.method == "mtp"
    # For decode, the virtual batch is batch size * num_tokens with seq_len 1.
    num_requests = last_token_indices.shape[0]

    if self.method == "eagle3":
        eagle3_model = self.model.model
        assert isinstance(eagle3_model, Eagle3LlamaForCausalLM)
        target_hidden_states = eagle3_model.combine_hidden_states(target_hidden_states)
        assert target_hidden_states.shape[-1] == self.hidden_size

    outputs = self.model(
        input_ids=target_token_ids,
        positions=target_positions,
        hidden_states=target_hidden_states,
        inputs_embeds=None,
        attn_metadata=common_attn_metadata,
    )
    # "mtp" models hand back one tensor; the others a (last, carried) pair.
    final_hidden, carried_hidden = (outputs, outputs) if uses_mtp else outputs

    final_hidden = final_hidden.view(-1, final_hidden.shape[-1])
    first_logits = self.model.compute_logits(final_hidden[last_token_indices])
    first_draft = first_logits.argmax(dim=-1)

    # Nothing more to do when a single draft token is requested.
    if self.num_speculative_tokens == 1:
        return first_draft.view(-1, 1)

    # Flattened positions; tensor indexing copies, so the in-place
    # increments in the loop leave the caller's tensor alone.
    flat_positions = target_positions.view(-1)
    positions = flat_positions[last_token_indices]

    # "mtp" carries the target hidden states forward, the rest the model's own.
    carried_source = target_hidden_states if uses_mtp else carried_hidden
    # [batch_size, hidden_size]
    carried_hidden = carried_source.view(-1, carried_source.shape[-1])[last_token_indices]

    drafts = [first_draft]

    # prepare_attn_metadata works on host positions because compile-only
    # warmup mode performs no real computation.
    positions_cpu = flat_positions.cpu()[last_token_indices.cpu()]

    # One decode step per remaining draft token.
    for _ in range(self.num_speculative_tokens - 1):
        positions += 1
        positions_cpu += 1
        # Positions at or past max_model_len are fed to the model as 0.
        step_positions = torch.where(positions >= self.max_model_len, 0, positions)

        attn_metadata = self.prepare_attn_metadata(block_table_cpu_tensor, positions_cpu, model_runner)

        outputs = self.model(
            # argmax() gives int64; the int32 cast matters for the compiled model.
            input_ids=drafts[-1].int().view(-1, 1),
            positions=step_positions.view(-1, 1),
            hidden_states=carried_hidden.view(-1, 1, carried_hidden.shape[-1]),
            inputs_embeds=None,
            attn_metadata=attn_metadata,
        )
        final_hidden, carried_hidden = (outputs, outputs) if uses_mtp else outputs

        # [batch_size, 1, hidden_size] -> [batch_size, hidden_size], first
        # num_requests rows only.
        final_hidden = final_hidden.view(-1, final_hidden.shape[-1])[:num_requests]
        carried_hidden = carried_hidden.view(-1, carried_hidden.shape[-1])[:num_requests]

        drafts.append(self.model.compute_logits(final_hidden).argmax(dim=-1))

    # [batch_size, num_speculative_tokens]
    return torch.stack(drafts, dim=1)