Skip to content

vllm.lora.ops.xla_ops.lora_ops

bgmv_expand

bgmv_expand(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    lora_indices_tensor: Tensor,
    add_inputs: bool = True,
)

Parameters:

Name Type Description Default
inputs Tensor

Input tensor of shape [num_tokens, hidden_size].

required
lora_b_weights Tensor

LoRA weights of shape [num_loras, lora_rank, hidden_size].

required
output_tensor Tensor

Output tensor of shape [num_tokens, hidden_size * num_slices]. It is not modified in place; a new tensor is returned.

required
lora_indices_tensor Tensor

Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token.

required
add_inputs bool

Whether to add the computed LoRA result to the existing contents of output_tensor (True) or to return the result on its own (False).

True
Source code in vllm/lora/ops/xla_ops/lora_ops.py
def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
    """Expand per-token LoRA activations with the XLA ``bgmv`` op.

    The ``bgmv`` result is zero-padded on the right (or truncated) to
    ``output_tensor.shape[1]`` columns and cut to ``output_tensor.shape[0]``
    rows. When ``bgmv`` yields exactly one row but ``output_tensor`` has
    several, only that single row is kept so it broadcasts over
    ``output_tensor`` in the addition.

    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
            NOTE(review): for an expand op the inner dim is presumably
            lora_rank — confirm against the bgmv kernel.
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): Output tensor of shape
            [num_tokens, hidden_size * num_slices]. Its shape determines the
            result's shape and its values are added when ``add_inputs`` is
            True; it is not written in place.
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        add_inputs (bool): If True, return ``output_tensor`` plus the padded
            result; otherwise return the padded result alone.

    Returns:
        torch.Tensor: A new tensor with the (optionally accumulated) result.
    """
    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)

    out_rows = output_tensor.shape[0]
    out_cols = output_tensor.shape[1]

    # Keep a single row when bgmv produced only one, so it broadcasts against
    # a multi-row output_tensor; otherwise match output_tensor's row count.
    limit = out_rows
    if outputs.shape[0] == 1 and out_rows != 1:
        limit = 1

    pad_cols = out_cols - outputs.shape[1]
    if pad_cols > 0:
        # new_zeros inherits dtype and device from `outputs`. A bare
        # torch.zeros(...) would use the default dtype (float32), and
        # torch.cat would then type-promote a bf16/fp16 result to float32.
        outputs = torch.cat(
            (outputs, outputs.new_zeros((outputs.shape[0], pad_cols))),
            dim=1)

    # Slicing to out_cols also truncates a result that is wider than
    # output_tensor instead of failing on a negative padding size.
    outputs = outputs[:limit, :out_cols]

    if add_inputs:
        return output_tensor + outputs
    return outputs

bgmv_expand_slice

bgmv_expand_slice(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    lora_indices_tensor: Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
)

Parameters:

Name Type Description Default
inputs Tensor

Input tensor of shape [num_tokens, hidden_size].

required
lora_b_weights Tensor

LoRA weights of shape [num_loras, lora_rank, hidden_size].

required
output_tensor Tensor

Output tensor of shape [num_tokens, hidden_size * num_slices]. It is not modified in place; a new tensor is returned.

required
lora_indices_tensor Tensor

Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token.

required
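slice_offset int

Column offset at which the LoRA result is placed; the returned result is zero in the columns before it.

required
slice_size int

Number of columns occupied by the LoRA result; the returned result is zero in the columns after slice_offset + slice_size.

required
add_inputs bool

Whether to add the computed LoRA result to the existing contents of output_tensor (True) or to return the result on its own (False).

True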
Source code in vllm/lora/ops/xla_ops/lora_ops.py
def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    """Expand per-token LoRA activations into one column slice of the output.

    The ``bgmv`` result is placed at columns
    ``[slice_offset, slice_offset + slice_size)`` of a zero tensor that has
    ``output_tensor.shape[1]`` columns.

    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): Output tensor of shape
            [num_tokens, hidden_size * num_slices]. It is not written in
            place.
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        slice_offset (int): Number of zero columns placed before the bgmv
            result.
        slice_size (int): Column width reserved for the bgmv result; it should
            equal the bgmv result's column count so the total width matches
            ``output_tensor.shape[1]``.
        add_inputs (bool): If True, return ``output_tensor`` plus the padded
            result; otherwise return the padded result alone.

    Returns:
        torch.Tensor: A new tensor with the (optionally accumulated) result.
    """
    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
    n_tokens = outputs.size(0)
    trailing_cols = output_tensor.shape[1] - (slice_offset + slice_size)

    # new_zeros inherits dtype and device from `outputs`. A bare
    # torch.zeros(...) would use the default dtype (float32), and torch.cat
    # would then type-promote a bf16/fp16 result to float32.
    outputs = torch.cat((
        outputs.new_zeros((n_tokens, slice_offset)),
        outputs,
        outputs.new_zeros((n_tokens, trailing_cols)),
    ),
                        dim=1)

    if add_inputs:
        return output_tensor + outputs
    return outputs

bgmv_shrink

bgmv_shrink(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    lora_indices_tensor: Tensor,
    scaling: float = 1.0,
)

Parameters:

Name Type Description Default
inputs Tensor

Input tensor of shape [num_tokens, hidden_size].

required
lora_b_weights Tensor

LoRA weights of shape [num_loras, lora_rank, hidden_size].

required
output_tensor Tensor

Unused placeholder; the function returns a new tensor instead of writing into it.

required
lora_indices_tensor Tensor

Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token.

required
scaling float

Scalar multiplier applied to the output.

1.0
Source code in vllm/lora/ops/xla_ops/lora_ops.py
def bgmv_shrink(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
    """Project per-token inputs through their LoRA weights and scale.

    Each token is multiplied by the LoRA matrix selected by
    ``lora_indices_tensor`` via ``torch.ops.xla.bgmv``, and the product is
    multiplied by ``scaling``.

    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): Unused placeholder; a new tensor is
            returned instead of writing into it.
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        scaling (float, optional): Scalar multiplier applied to the result.

    Returns:
        torch.Tensor: The scaled bgmv product.
    """
    projected = torch.ops.xla.bgmv(inputs, lora_b_weights,
                                   lora_indices_tensor)
    return scaling * projected