Merge multimodal_embeddings into inputs_embeds by overwriting the
positions in inputs_embeds corresponding to placeholder tokens in
input_ids.
Note
This updates inputs_embeds in place.
Source code in vllm_gaudi/models/utils.py
| def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
Note:
This updates ``inputs_embeds`` in place.
"""
if len(multimodal_embeddings) == 0:
return inputs_embeds
import habana_frameworks.torch.core as htcore
htcore.mark_step()
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype
if inputs_embeds.ndim == 3 and mm_embeds_flat.ndim == 2:
original_shape = inputs_embeds.shape
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.shape[-1])
if is_multimodal.dtype == torch.bool:
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype))
else:
inputs_embeds.index_copy_(0, is_multimodal, mm_embeds_flat.to(dtype=input_dtype))
return inputs_embeds.view(original_shape)
if is_multimodal.dtype != torch.bool:
return inputs_embeds.index_copy_(0, is_multimodal, mm_embeds_flat)
try:
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype))
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders") from e
raise ValueError("Error during masked scatter operation") from e
return inputs_embeds
|