Skip to content

vllm_gaudi.models.utils

_merge_multimodal_embeddings

_merge_multimodal_embeddings(
    inputs_embeds: Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: Tensor,
) -> Tensor

Merge multimodal_embeddings into inputs_embeds by overwriting the positions in inputs_embeds that correspond to multimodal placeholder tokens, as selected by is_multimodal (either a boolean mask or an integer index tensor).

Note

This updates inputs_embeds in place.

Source code in vllm_gaudi/models/utils.py
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    rows of ``inputs_embeds`` that correspond to multimodal placeholder tokens.

    Args:
        inputs_embeds: Embeddings to update. When it is 3-D and the flattened
            multimodal embeddings are 2-D, it is viewed as ``(-1, hidden)``
            for the write and viewed back to its original shape afterwards.
        multimodal_embeddings: Embeddings to insert. They are flattened with
            ``_flatten_embeddings`` and cast to ``inputs_embeds.dtype``.
        is_multimodal: Either a boolean mask selecting the placeholder
            positions, or an integer index tensor of the (flattened) rows
            to overwrite.

    Returns:
        ``inputs_embeds`` in its original shape, after the merge. If
        ``multimodal_embeddings`` is empty, it is returned unchanged.

    Raises:
        ValueError: On the non-3-D boolean-mask path, if the masked scatter
            fails (e.g. fewer multimodal tokens than placeholders). Errors
            from the other paths propagate unchanged.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    import habana_frameworks.torch.core as htcore

    # NOTE(review): presumably flushes the accumulated lazy-mode graph before
    # the data-dependent scatter below -- confirm with the HPU backend docs.
    htcore.mark_step()

    # Cast once so every write path receives a source whose dtype matches the
    # destination; index_copy_ in particular rejects mismatched dtypes.
    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings).to(dtype=inputs_embeds.dtype)

    if inputs_embeds.ndim == 3 and mm_embeds_flat.ndim == 2:
        original_shape = inputs_embeds.shape
        # view() shares storage, so writes below land in the caller's tensor.
        flat_embeds = inputs_embeds.view(-1, inputs_embeds.shape[-1])
        if is_multimodal.dtype == torch.bool:
            # Flatten the mask so it lines up with the flattened rows; a mask
            # that is already 1-D yields the same (N, 1) shape as unsqueeze(-1).
            flat_embeds.masked_scatter_(is_multimodal.reshape(-1, 1), mm_embeds_flat)
        else:
            flat_embeds.index_copy_(0, is_multimodal, mm_embeds_flat)
        return flat_embeds.view(original_shape)

    if is_multimodal.dtype != torch.bool:
        return inputs_embeds.index_copy_(0, is_multimodal, mm_embeds_flat)

    try:
        # NOTE: This can avoid D2H sync (#22105), but fails to
        # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat)
    except RuntimeError as e:
        num_actual_tokens = len(mm_embeds_flat)
        num_expected_tokens = is_multimodal.sum().item()

        if num_actual_tokens != num_expected_tokens:
            expr = _embedding_count_expression(multimodal_embeddings)

            raise ValueError(f"Attempted to assign {expr} = {num_actual_tokens} "
                             f"multimodal tokens to {num_expected_tokens} placeholders") from e

        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds