@jax.jit
def _bgmv(
    idxs: jax.Array,  # (T, ) int32
    inputs: jax.Array,  # (T, D) model dtype
    loras: jax.Array,  # (N, L, D) model dtype
) -> jax.Array:  # (T, L) model dtype
    """Launch the `_bgmv_kernel` Pallas TPU kernel over a 3-D tiled grid.

    The grid iterates over (token blocks, LoRA-rank blocks, feature-dim
    blocks). For grid point (i, j, k) the kernel is handed:

    * the (i, k) tile of `inputs`, shaped (TOKENS_BLOCK, DIM_BLOCK_SIZE);
    * the (0, j, k) tile of `loras`, shaped
      (N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE), i.e. the full leading axis;
    * the (i, j) tile of the output, shaped (TOKENS_BLOCK, LORA_RANK_BLOCK).

    `idxs` is the single scalar-prefetch operand. The index maps receive it
    as their last argument but do not read it.
    NOTE(review): presumably the kernel uses `idxs` to pick one LoRA per
    token, and the two float32 scratch buffers to accumulate across the
    feature-dim axis -- confirm in `_bgmv_kernel`, which is not shown here.

    NOTE(review): grid extents use floor division, so any trailing partial
    block along T, L or D is not visited. Callers are assumed to pad shapes
    to multiples of the block sizes -- TODO confirm.

    Args:
        idxs: Scalar-prefetch operand, passed first to the kernel call.
        inputs: 2-D array; its dtype is used for the output.
        loras: 3-D array; its trailing axis is tiled with the same block
            size and grid axis as the trailing axis of `inputs`.

    Returns:
        An array of shape (inputs.shape[0], loras.shape[1]) with the dtype
        of `inputs`.
    """
    num_tokens, hidden_dim = inputs.shape
    num_loras, lora_rank, _ = loras.shape

    # Output tile shape; the scratch buffers use the same shape in float32.
    out_tile = (TOKENS_BLOCK, LORA_RANK_BLOCK)

    # Index maps take the three grid coordinates followed by the
    # scalar-prefetch operand, and return the block index of the tile.
    def _inputs_tile(tok_blk, rank_blk, dim_blk, prefetched):
        return (tok_blk, dim_blk)

    def _loras_tile(tok_blk, rank_blk, dim_blk, prefetched):
        # Always block 0 on the leading axis: the block spans all N entries.
        return (0, rank_blk, dim_blk)

    def _output_tile(tok_blk, rank_blk, dim_blk, prefetched):
        return (tok_blk, rank_blk)

    launch_grid = (
        num_tokens // TOKENS_BLOCK,
        lora_rank // LORA_RANK_BLOCK,
        hidden_dim // DIM_BLOCK_SIZE,
    )

    operand_specs = [
        pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE), _inputs_tile),
        pl.BlockSpec((num_loras, LORA_RANK_BLOCK, DIM_BLOCK_SIZE),
                     _loras_tile),
    ]
    result_spec = pl.BlockSpec(out_tile, _output_tile)

    # Two separate float32 VMEM scratch allocations of the output-tile shape.
    scratch_buffers = [pltpu.VMEM(out_tile, jnp.float32) for _ in range(2)]

    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=1,
        grid=launch_grid,
        in_specs=operand_specs,
        out_specs=result_spec,
        scratch_shapes=scratch_buffers,
    )

    # Token and rank axes are marked "parallel"; the feature-dim axis is
    # marked "arbitrary".
    compiler_params = pltpu.TPUCompilerParams(
        dimension_semantics=("parallel", "parallel", "arbitrary"))

    kernel = functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK)
    launch = pl.pallas_call(
        kernel=kernel,
        out_shape=jax.ShapeDtypeStruct((num_tokens, lora_rank),
                                       dtype=inputs.dtype),
        grid_spec=grid_spec,
        compiler_params=compiler_params,
        name="bgmv",
    )
    return launch(idxs, inputs, loras)