# Triton forward kernel for scaled-dot-product attention with optional
# causal masking, variable-length (packed) sequences, GQA/MQA head sharing,
# additive bias, ALiBi and 8-bit (IS_EIGHT_BIT / IS_EIGHT_BIT_KV) paths.
#
# Launch grid (see the tl.program_id reads below):
#   axis 0 -> start_m: index of the BLOCK_M-row tile of queries,
#   axis 1 -> off_h_q: query head index,
#   axis 2 -> off_z:   batch / sequence index.
#
# Each program produces one BLOCK_M x BLOCK_DMODEL tile of `Out` and the
# matching BLOCK_M entries of `L`. `L` receives m_i + log2(l_i) and is
# addressed as a dense (batch, HQ, MAX_SEQLENS_Q) buffer.
#
# Q/K/V/Out are addressed through explicit (z, h, seq, dim) strides. With
# VARLEN, cu_seqlens_q / cu_seqlens_k hold cumulative offsets: entry off_z is
# the start and entry off_z + 1 the end of this sequence in the packed token
# axis. Each K/V head is shared by HQ // HK consecutive query heads.
#
# q/k/v/p descales are read only when IS_EIGHT_BIT; the *_has_singleton flags
# choose between one shared scale and one scale per head. o_descale is loaded
# whenever o_descale_ptr is given, but only applied on the 8-bit GEMM path.
@triton.autotune(
    configs=autotune_configs,
    key=autotune_keys,
    use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
    Q,
    K,
    V,
    bias,
    SM_SCALE: tl.constexpr,
    L,
    Out,
    stride_qz: tl.int64,
    stride_qh: tl.int64,
    stride_qm: tl.int64,
    stride_qk: tl.int64,
    stride_kz: tl.int64,
    stride_kh: tl.int64,
    stride_kn: tl.int64,
    stride_kk: tl.int64,
    stride_vz: tl.int64,
    stride_vh: tl.int64,
    stride_vk: tl.int64,
    stride_vn: tl.int64,
    stride_oz: tl.int64,
    stride_oh: tl.int64,
    stride_om: tl.int64,
    stride_on: tl.int64,
    stride_bz: tl.int64,
    stride_bh: tl.int64,
    stride_bm: tl.int64,
    stride_bn: tl.int64,
    stride_az: tl.int64,
    stride_ah: tl.int64,
    q_descale_ptr,
    k_descale_ptr,
    p_scale_ptr,
    p_descale_ptr,
    o_descale_ptr,
    v_descale_ptr,
    q_descale_has_singleton: tl.constexpr,
    k_descale_has_singleton: tl.constexpr,
    p_descale_has_singleton: tl.constexpr,
    v_descale_has_singleton: tl.constexpr,
    cu_seqlens_q,
    cu_seqlens_k,
    philox_seed,
    NUM_CU: tl.constexpr,
    GRID_CU_MULTIP: tl.constexpr,
    B: tl.constexpr,
    philox_offset_base,
    encoded_softmax,
    alibi_slopes,
    HQ: tl.constexpr,
    HK: tl.constexpr,
    IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr,
    MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SHOULD_PRE_LOAD_V: tl.constexpr,
    USE_BIAS: tl.constexpr,
    SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
    USE_ALIBI: tl.constexpr,
    IS_EIGHT_BIT: tl.constexpr,
    USE_P_SCALE: tl.constexpr,
    IS_EIGHT_BIT_KV: tl.constexpr,
    QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
):
    # Loaded up front; only used by the 8-bit GEMM epilogue further down.
    if o_descale_ptr is not None:
        o_descale = tl.load(o_descale_ptr)
    start_m: tl.int64 = tl.program_id(0)
    off_h_q: tl.int64 = tl.program_id(1)
    off_z: tl.int64 = tl.program_id(2)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
    offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
    offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64)
    # This kernel avoids early `return` statements (per the original note,
    # Triton did not support them here); this flag gates the remaining work.
    continue_condition = True
    if VARLEN:
        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
        # We have a one-size-fits-all grid in axis 0, so some tiles start at
        # or beyond the end of this sequence and contain no valid query rows;
        # skip them. `>=` (not `>`): a tile that starts exactly at seqlen_q
        # is already empty.
        if start_m * BLOCK_M >= seqlen_q:
            continue_condition = False
        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    else:
        cu_seqlens_q_start = 0
        cu_seqlens_k_start = 0
        seqlen_q = MAX_SEQLENS_Q
        seqlen_k = MAX_SEQLENS_K
    if continue_condition:
        # Now we compute whether we need to exit early due to causal
        # masking. This is because for seqlen_q > seqlen_k, M rows of the
        # attn scores are completely masked, resulting in 0s written to the
        # output, and inf written to LSE. We don't need to do any GEMMs in
        # this case. This block of code determines what N is, and if this
        # WG is operating on those M rows.
        n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
        if (IS_CAUSAL):
            # If seqlen_q == seqlen_k, the attn scores are a square matrix.
            # If seqlen_q != seqlen_k, attn scores are rectangular which
            # means the causal mask boundary is bottom right aligned, and
            # ends at either the top edge (seqlen_q < seqlen_k) or left
            # edge. This captures the decrease in n_blocks if we have a
            # rectangular attn matrix
            n_blocks_seqlen = cdiv_fn(
                (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
            # This is what adjusts the block_max for the current WG, only
            # if IS_CAUSAL. Otherwise we want to always iterate through all
            # n_blocks
            n_blocks = min(n_blocks, n_blocks_seqlen)
            # If we have no blocks after adjusting for seqlen deltas, this
            # WG is part of the blocks that are all 0. We exit early.
            if n_blocks <= 0:
                o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
                            cu_seqlens_q_start * stride_om)
                o_ptrs = (o_offset + offs_m[:, None] * stride_om +
                          offs_d[None, :] * stride_on)
                acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
                o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to(
                    [BLOCK_M, BLOCK_DMODEL])
                # BLOCK_DMODEL can be padded past the real head size. Mask
                # those columns, as the regular O store at the end of the
                # kernel does; otherwise the zeros land on memory past this
                # head's row (possibly other heads/tokens, depending on the
                # caller's layout).
                if IS_ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL:
                    o_ptrs_mask = o_ptrs_mask & (offs_d[None, :]
                                                 < IS_ACTUAL_BLOCK_DMODEL)
                # We still need to write 0s to the result
                tl.store(o_ptrs, acc, mask=o_ptrs_mask)
                # The tensor allocated for L is based on MAX_SEQLENS_Q as
                # that is statically known.
                l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
                          off_h_q * MAX_SEQLENS_Q + offs_m)
                # We store inf to LSE, not -inf because in the bwd pass,
                # we subtract this from qk which makes it -inf, such that
                # exp(qk - inf) = 0 for these masked blocks.
                l_value = tl.full([BLOCK_M],
                                  value=float("inf"),
                                  dtype=tl.float32)
                l_ptrs_mask = offs_m < MAX_SEQLENS_Q
                tl.store(l_ptrs, l_value, mask=l_ptrs_mask)
                # TODO: Should dropout and return encoded softmax be
                # handled here too?
                continue_condition = False
        if continue_condition:
            # If MQA / GQA, set the K and V head offsets appropriately.
            GROUP_SIZE: tl.constexpr = HQ // HK
            off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
            # Number of valid keys in the last, partially filled BLOCK_N
            # block (0 when seqlen_k is a multiple of BLOCK_N).
            n_extra_tokens = 0
            if seqlen_k < BLOCK_N:
                n_extra_tokens = BLOCK_N - seqlen_k
            elif seqlen_k % BLOCK_N:
                n_extra_tokens = seqlen_k % BLOCK_N
            USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL
                                             != BLOCK_DMODEL)
            # Compute pointers for all the tensors used in this kernel.
            q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh +
                        cu_seqlens_q_start * stride_qm)
            q_ptrs = (q_offset + offs_m[:, None] * stride_qm +
                      offs_d[None, :] * stride_qk)
            k_offset = (K + off_z * stride_kz + off_h_k * stride_kh +
                        cu_seqlens_k_start * stride_kn)
            k_ptrs = (k_offset + offs_d[:, None] * stride_kk +
                      offs_n[None, :] * stride_kn)
            v_offset = (V + off_z * stride_vz + off_h_k * stride_vh +
                        cu_seqlens_k_start * stride_vk)
            v_ptrs = (v_offset + offs_n[:, None] * stride_vk +
                      offs_d[None, :] * stride_vn)
            # Compute pointers for all scale tensors used in this kernel.
            IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & (
                not IS_EIGHT_BIT_KV)
            if IS_EIGHT_BIT:
                if k_descale_has_singleton:
                    k_descale_ptrs = k_descale_ptr
                else:
                    k_descale_ptrs = k_descale_ptr + off_h_k
                if v_descale_has_singleton:
                    v_descale_ptrs = v_descale_ptr
                else:
                    v_descale_ptrs = v_descale_ptr + off_h_k
                if not IS_EIGHT_BIT_KV:
                    if q_descale_has_singleton:
                        q_descale_ptrs = q_descale_ptr
                    else:
                        q_descale_ptrs = q_descale_ptr + off_h_q
                if USE_P_SCALE:
                    if p_descale_has_singleton:
                        p_scale_ptrs = p_scale_ptr
                        p_descale_ptrs = p_descale_ptr
                    else:
                        p_scale_ptrs = p_scale_ptr + off_h_q
                        p_descale_ptrs = p_descale_ptr + off_h_q
            if USE_BIAS:
                # NOTE(review): this offset omits off_z * stride_bz (and,
                # for VARLEN, cu_seqlens_q_start), so every batch entry
                # reads the same bias rows. Confirm callers only pass a
                # batch-broadcast bias.
                bias_offset = off_h_q * stride_bh
                bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm +
                             offs_n[None, :] * stride_bn)
            else:
                bias_ptrs = None
            if USE_ALIBI:
                a_offset = off_z * stride_az + off_h_q * stride_ah
                alibi_slope = tl.load(alibi_slopes + a_offset)
            else:
                alibi_slope = None
            batch_philox_offset = 0
            # We can ask to return the dropout mask without doing any
            # dropout. In this case, we return an invalid pointer so
            # indicate the mask is not valid.
            if SHOULD_RETURN_ENCODED_SOFTMAX:
                # NOTE(review): no off_z term here either; with B > 1 every
                # batch entry writes to the same region. Confirm intended.
                encoded_sm_base = (encoded_softmax +
                                   off_h_q * seqlen_q * seqlen_k)
                encoded_sm_ptrs = (encoded_sm_base +
                                   offs_m[:, None] * seqlen_k +
                                   offs_n[None, :])
            else:
                encoded_sm_ptrs = None
            # Running row max (m_i), running row sum (l_i) and the output
            # accumulator for the online softmax.
            m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
            l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
            # scale sm_scale by log_2(e) and use 2^x in the loop as we do
            # not have native e^x support in HW.
            QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089
            # Q is loaded once at the beginning and shared by all N blocks.
            q_ptrs_mask = offs_m[:, None] < seqlen_q
            if USE_PADDED_HEAD:
                q_ptrs_mask = q_ptrs_mask & (offs_d[None, :]
                                             < IS_ACTUAL_BLOCK_DMODEL)
            q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
            if IS_EIGHT_BIT:
                k_descale = tl.load(k_descale_ptrs)
                v_descale = tl.load(v_descale_ptrs)
                q_descale = None if IS_EIGHT_BIT_KV else tl.load(
                    q_descale_ptrs)
                if USE_P_SCALE:
                    p_scale = tl.load(p_scale_ptrs)
                    p_descale = tl.load(p_descale_ptrs)
                else:
                    p_scale = None
                    p_descale = None
            else:
                q_descale = None
                k_descale = None
                v_descale = None
                p_scale = None
                p_descale = None
            # Here we compute how many full and masked blocks we have.
            padded_block_k = n_extra_tokens != 0
            is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
            if IS_CAUSAL:
                # There are always at least BLOCK_M // BLOCK_N masked
                # blocks. Additionally there might be one more due to
                # dissimilar seqlens.
                masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
            else:
                # Padding on Q does not need to be masked in the FA loop.
                masked_blocks = padded_block_k
            # if IS_CAUSAL, not is_modulo_mn does not always result in an
            # additional block. In this case we might exceed n_blocks so
            # pick the min.
            masked_blocks = min(masked_blocks, n_blocks)
            n_full_blocks = n_blocks - masked_blocks
            block_min = 0
            block_max = n_blocks * BLOCK_N
            # Compute for full blocks. Here we set causal to false
            # regardless of its actual value because there is no masking.
            # Similarly we do not need padding.
            if n_full_blocks > 0:
                block_max = (n_blocks - masked_blocks) * BLOCK_N
                acc, l_i, m_i = _attn_fwd_inner(
                    acc,
                    l_i,
                    m_i,
                    q,
                    k_ptrs,
                    v_ptrs,
                    bias_ptrs,
                    stride_kn,
                    stride_vk,
                    stride_bn,
                    start_m,
                    seqlen_k,
                    seqlen_q,
                    philox_seed,
                    batch_philox_offset,
                    encoded_sm_ptrs,
                    # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
                    block_min,
                    block_max,
                    0,
                    0,
                    0,
                    alibi_slope,
                    q_descale,
                    k_descale,
                    v_descale,
                    p_scale,
                    # IS_CAUSAL, ....
                    False,
                    BLOCK_M,
                    BLOCK_DMODEL,
                    BLOCK_N,
                    offs_m,
                    offs_n,
                    # _, SHOULD_MASK_STEPS, ...
                    SHOULD_PRE_LOAD_V,
                    False,
                    SHOULD_RETURN_ENCODED_SOFTMAX,
                    USE_PADDED_HEAD,
                    IS_ACTUAL_BLOCK_DMODEL,
                    QK_SCALE,
                    IS_EIGHT_BIT_GEMM,
                    USE_P_SCALE,
                    IS_EIGHT_BIT_KV,
                    QUANT_DTYPE)
                # The masked pass below starts where the full pass ended.
                block_min = block_max
                block_max = n_blocks * BLOCK_N
            tl.debug_barrier()
            # Remaining blocks, if any, need masking: they straddle the
            # causal boundary and/or contain the padded tail of K
            # (n_extra_tokens), so SHOULD_MASK_STEPS is True below.
            if (masked_blocks > 0):
                if IS_CAUSAL:
                    offs_n_causal = offs_n + (seqlen_q - seqlen_k)
                else:
                    offs_n_causal = 0
                # Skip the pointers past the blocks already consumed above.
                k_ptrs += n_full_blocks * BLOCK_N * stride_kn
                v_ptrs += n_full_blocks * BLOCK_N * stride_vk
                if USE_BIAS:
                    bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
                if SHOULD_RETURN_ENCODED_SOFTMAX:
                    encoded_sm_ptrs += n_full_blocks * BLOCK_N
                acc, l_i, m_i = _attn_fwd_inner(
                    acc,
                    l_i,
                    m_i,
                    q,
                    k_ptrs,
                    v_ptrs,
                    bias_ptrs,
                    stride_kn,
                    stride_vk,
                    stride_bn,
                    start_m,
                    seqlen_k,
                    seqlen_q,
                    philox_seed,
                    batch_philox_offset,
                    encoded_sm_ptrs,
                    block_min,
                    block_max,
                    offs_n_causal,
                    masked_blocks,
                    n_extra_tokens,
                    alibi_slope,
                    q_descale,
                    k_descale,
                    v_descale,
                    p_scale,
                    IS_CAUSAL,
                    BLOCK_M,
                    BLOCK_DMODEL,
                    BLOCK_N,
                    offs_m,
                    offs_n,
                    # _, SHOULD_MASK_STEPS, ...
                    SHOULD_PRE_LOAD_V,
                    True,
                    SHOULD_RETURN_ENCODED_SOFTMAX,
                    USE_PADDED_HEAD,
                    IS_ACTUAL_BLOCK_DMODEL,
                    QK_SCALE,
                    IS_EIGHT_BIT_GEMM,
                    USE_P_SCALE,
                    IS_EIGHT_BIT_KV,
                    QUANT_DTYPE)
            if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV:
                if USE_P_SCALE:
                    acc *= p_descale
                acc *= v_descale
            # epilogue
            # This helps the compiler do Newton Raphson on l_i vs on acc
            # which is much larger.
            l_recip = 1 / l_i[:, None]
            acc = acc * l_recip
            # If seqlen_q > seqlen_k but the delta is not a multiple of
            # BLOCK_M, then we have one block with a row of all NaNs which
            # come from computing softmax over a row of all
            # -infs (-inf - inf = NaN). We check for that here and store 0s
            # where there are NaNs as these rows should've been zeroed out.
            end_m_idx = (start_m + 1) * BLOCK_M
            start_m_idx = start_m * BLOCK_M
            causal_start_idx = seqlen_q - seqlen_k
            if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV:  # noqa: SIM102
                if o_descale_ptr is not None:
                    acc = quant_fp8(acc, o_descale)
            acc = acc.to(Out.type.element_ty)
            if IS_CAUSAL:  # noqa: SIM102
                if (causal_start_idx > start_m_idx
                        and causal_start_idx < end_m_idx):
                    out_mask_boundary = tl.full((BLOCK_DMODEL, ),
                                                causal_start_idx,
                                                dtype=tl.int32)
                    mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
                    out_ptrs_mask = (mask_m_offsets[:, None]
                                     >= out_mask_boundary[None, :])
                    z = tl.zeros((1, ), tl.float32)
                    acc = tl.where(out_ptrs_mask, acc,
                                   z.to(acc.type.element_ty))
            # write back LSE
            l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
                      off_h_q * MAX_SEQLENS_Q + offs_m)
            # If seqlen_q not multiple of BLOCK_M, we need to mask out the
            # last few rows. This is only true for the last M block.
            # For others, overflow_size will be negative.
            overflow_size = end_m_idx - seqlen_q
            if overflow_size > 0:
                boundary = tl.full((BLOCK_M, ),
                                   BLOCK_M - overflow_size,
                                   dtype=tl.int32)
                l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
                tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
            else:
                tl.store(l_ptrs, m_i + tl.math.log2(l_i))
            # write back O
            o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
                        cu_seqlens_q_start * stride_om)
            o_ptrs = (o_offset + offs_m[:, None] * stride_om +
                      offs_d[None, :] * stride_on)
            o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
            if overflow_size > 0:
                o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
            if USE_PADDED_HEAD:
                o_ptrs_mask = o_ptrs_mask & (offs_d[None, :]
                                             < IS_ACTUAL_BLOCK_DMODEL)
            tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)