vllm.v1.attention.ops.triton_attention_helpers ¶
Shared @triton.jit helpers used by the unified attention kernel and reduce_segments.
These are plain attention-loop helpers — mask building, ALiBi / QQ-bias score post-processing, online-softmax bookkeeping, tile-loop bounds, sequence lookup — extracted so the 2D and 3D paths of the unified kernel (and any future consumer) share a single implementation.
Functions:
- apply_alibi_to_score – Add the ALiBi positional bias (linear or sqrt variant) to S in-place.
- apply_softcap – Softcap (aka tanh-style clamp) used to bound attention scores.
- cdiv_fn – Ceiling division. Kept as a helper to keep kernel bodies terse.
- compute_kv_seq_mask – Build the KV mask for one tile.
- compute_tile_loop_bounds – Compute the tile-loop bounds (loop_lo, loop_hi) and the derived max_seq_prefix_len used for per-tile masking.
- find_seq_idx – Binary search over the cumulative query-length prefix.
- init_softmax_M – Initial row-max M for the online softmax.
- load_qq_bias_tile – Load the qq-bias slice for keys that correspond to query rows.
- resolve_seq_and_query_len – Resolve the (sequence, q-block-within-sequence) pair and load the per-sequence lengths.
- softmax_step – Online softmax update for one tile.
- store_segm_reduce_scalars – Store per-segment M and L for reduce_segments to combine into the final softmax.
apply_alibi_to_score(S, alibi_slope, seq_offset, context_len, query_pos, USE_ALIBI_SQRT) ¶
Add the ALiBi positional bias (linear or sqrt variant) to S in-place.
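For orientation, the linear variant can be sketched in plain Python. This assumes the textbook ALiBi form, where the bias is the per-head slope times the signed key-minus-query distance. The name `alibi_bias_row` and its parameters are hypothetical, the sketch returns a new row instead of updating S in-place, and the sqrt variant selected by USE_ALIBI_SQRT is not shown because its exact form is not stated here.

```python
def alibi_bias_row(scores, slope, query_abs_pos, key_start):
    """Return the scores of one query row with a linear ALiBi bias added.

    Assumption: bias = slope * (key_pos - query_pos), so keys further in the
    past receive a larger penalty. All names here are illustrative only.
    """
    return [
        s + slope * ((key_start + j) - query_abs_pos)
        for j, s in enumerate(scores)
    ]


# One query at absolute position 5 scoring keys 2..5 with slope 0.5.
biased = alibi_bias_row([0.0, 0.0, 0.0, 0.0], slope=0.5, query_abs_pos=5, key_start=2)
```

The key at the query's own position gets no penalty, and each step further back subtracts one more `slope`.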
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
apply_softcap(S, x) ¶
Softcap (aka tanh-style clamp) used to bound attention scores.
Computes x * tanh(S / x), rewritten to avoid a direct tanh call.
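One way such a rewrite can look is to expand tanh into exponentials. The sketch below is scalar Python and assumes that expansion; the actual kernel may arrange the arithmetic differently, and `softcap_sketch` is a hypothetical name.

```python
import math


def softcap_sketch(s, cap):
    """Compute cap * tanh(s / cap) using exponentials only.

    Uses tanh(z) = (e^z - e^-z) / (e^z + e^-z). For very large |s / cap| the
    exponentials would overflow; this sketch does not guard against that.
    """
    z = s / cap
    e_pos = math.exp(z)
    e_neg = math.exp(-z)
    return cap * (e_pos - e_neg) / (e_pos + e_neg)
```

Small scores pass through almost unchanged, while large ones saturate near `cap` in magnitude.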
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
cdiv_fn(x, y) ¶
Ceiling division. Kept as a helper to keep kernel bodies terse.
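As a reminder of what ceiling division computes, here is a minimal Python sketch assuming positive integer operands; `cdiv_sketch` is a hypothetical name and the add-then-floor form is only one common way to write it.

```python
def cdiv_sketch(x, y):
    """Ceiling division for positive integers: the smallest n with n * y >= x."""
    return (x + y - 1) // y


# A 116-token prefix split into tiles of 16 keys needs 8 tiles.
num_tiles = cdiv_sketch(116, 16)
```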
compute_kv_seq_mask(query_abs_pos, seq_offset, seq_idx, seq_len, mm_prefix_range_ptr, SLIDING_WINDOW, USE_MM_PREFIX, MAX_MM_RANGES, USE_CAUSAL=True, USE_PER_SEQ_CAUSAL=False, per_seq_causal_ptr=None, CHUNK_LOOKBACK=-1, CHUNK_SIZE=-1) ¶
Build the KV mask for one tile.
Causal (key <= query) by default; AND-ed with either chunked attention (CHUNK_LOOKBACK >= 0) or sliding window (SLIDING_WINDOW > 0); OR-ed with the bidirectional ranges from mm_prefix_range when PrefixLM / multimodal attention is active. Order matches FlexAttention: (causal AND window) OR mm_prefix. Chunked attention takes precedence over sliding window when both are non-default — the launcher zeros CHUNK_LOOKBACK whenever sliding window is disabled.
When USE_PER_SEQ_CAUSAL is set, each sequence carries its own causal flag via per_seq_causal_ptr; non-causal sequences use a simple key < seq_len bound instead. USE_CAUSAL=False disables causal masking entirely.
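The combination order described above can be sketched for a single (query, key) pair in plain Python. This is a simplified illustration, not the kernel: chunked attention is left out, `kv_allowed` and its parameters are hypothetical, and three details are assumptions: positions are absolute indices within the sequence, the window keeps keys with `q_pos - k_pos < sliding_window` (two-sided for non-causal sequences), and `mm_ranges` holds inclusive `(start, end)` pairs.

```python
def kv_allowed(q_pos, k_pos, seq_len, sliding_window=0, mm_ranges=(), causal=True):
    """Decide whether one query position may attend to one key position.

    Implements (causal AND window) OR mm_prefix for a single pair.
    """
    if causal:
        allowed = k_pos <= q_pos
        if sliding_window > 0:
            allowed = allowed and (q_pos - k_pos) < sliding_window
    else:
        # Non-causal sequences only need the key to lie inside the sequence.
        allowed = k_pos < seq_len
        if sliding_window > 0:
            allowed = allowed and abs(q_pos - k_pos) < sliding_window
    # Bidirectional ranges re-enable pairs that both fall inside one range.
    in_mm = any(lo <= q_pos <= hi and lo <= k_pos <= hi for lo, hi in mm_ranges)
    return allowed or in_mm
```

Because the mm_prefix term is OR-ed last, a key that the causal or window rule rejects is still visible when it shares a bidirectional range with the query.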
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
compute_tile_loop_bounds(context_len, seq_len, cur_batch_query_len, q_block_local_idx, segm_idx_or_0, tiles_per_segment_or_0, TILE_SIZE, BLOCK_M, BLOCK_Q, num_queries_per_kv, SLIDING_WINDOW, USE_MM_PREFIX, IS_3D, USE_CAUSAL=True, USE_PER_SEQ_CAUSAL=False, CHUNK_LOOKBACK=-1, CHUNK_SIZE=-1) ¶
Compute the tile-loop bounds (loop_lo, loop_hi) and the derived max_seq_prefix_len used for per-tile masking.
Combines three concerns into one helper:
- Longest prefix spanned by any query token in this q-block. Clamped to seq_len (causal) or extended to it when mm_prefix is active or non-causal sequences need the full sequence.
- Sliding-window pruning: narrows [tile_start, tile_end) to only tiles that can contain an allowed key under SWA. For non-causal sequences, the window extends in both directions.
- 3D scoping: when IS_3D is True, further narrows to the segment's slice via (segm_idx * tiles_per_segment, (segm_idx + 1) * tiles_per_segment).
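The three concerns above can be sketched in plain Python for the causal case only. This is an illustration under assumptions, not the helper itself: mm_prefix, per-sequence causal flags and chunked attention are omitted, the prefix formula assumes each q-block covers `(BLOCK_M - 1) // num_queries_per_kv + 1` query tokens, and `tile_loop_bounds_sketch` is a hypothetical name.

```python
def tile_loop_bounds_sketch(context_len, seq_len, cur_batch_query_len,
                            q_block_local_idx, segm_idx, tiles_per_segment,
                            TILE_SIZE, BLOCK_M, BLOCK_Q, num_queries_per_kv,
                            SLIDING_WINDOW, IS_3D):
    """Return (loop_lo, loop_hi, max_seq_prefix_len) for a causal q-block."""
    rows = (BLOCK_M - 1) // num_queries_per_kv
    qpos_lo = q_block_local_idx * BLOCK_Q

    # 1) Longest prefix any query token in this block can see, clamped to seq_len.
    max_seq_prefix_len = min(context_len + qpos_lo + rows + 1, seq_len)
    tile_start, tile_end = 0, -(-max_seq_prefix_len // TILE_SIZE)  # ceil division

    # 2) Sliding-window pruning: skip tiles that hold no allowed key.
    if SLIDING_WINDOW > 0:
        qpos_hi = min(qpos_lo + rows, cur_batch_query_len - 1)
        first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
        last_allowed_key = context_len + qpos_hi
        tile_start = max(0, first_allowed_key // TILE_SIZE)
        tile_end = min(last_allowed_key // TILE_SIZE + 1, tile_end)

    # 3) 3D scoping: intersect with this segment's slice of tiles.
    loop_lo, loop_hi = tile_start, tile_end
    if IS_3D:
        loop_lo = max(tile_start, segm_idx * tiles_per_segment)
        loop_hi = min(tile_end, (segm_idx + 1) * tiles_per_segment)
    return loop_lo, loop_hi, max_seq_prefix_len
```

With a 100-token context, 16 new query tokens and 16-key tiles, the full causal range is tiles 0 to 8; a 32-token window trims that to tiles 4 to 8, and segment 1 of a 3-tiles-per-segment split covers tiles 3 to 6.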
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
find_seq_idx(query_start_len_ptr, target_idx, num_seqs, BLOCK_Q, use_q_block_mode) ¶
Binary search over the cumulative query-length prefix.
When use_q_block_mode is True, the prefix values are reshaped into units of BLOCK_Q plus one entry per boundary — matching the q-block grid laid out by the attention kernels. When False we search the plain cumulative-length prefix (used by reduce_segments which iterates over raw query tokens).
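The two search modes can be sketched over a plain Python list standing in for `query_start_len_ptr`. The q-block transform used here (`start // BLOCK_Q + i` for sequence `i`) is an assumption drawn from the description above, and `find_seq_idx_sketch` is a hypothetical name.

```python
def find_seq_idx_sketch(query_start_len, target_idx, num_seqs, BLOCK_Q, use_q_block_mode):
    """Return the index of the sequence that owns target_idx.

    query_start_len holds num_seqs + 1 cumulative query lengths, starting at 0.
    """
    left, right = 0, num_seqs
    while left < right:
        mid = (left + right) // 2
        val = query_start_len[mid]
        # In q-block mode, compare against the first q-block of sequence `mid`.
        mid_val = val // BLOCK_Q + mid if use_q_block_mode else val
        if mid_val <= target_idx:
            left = mid + 1
        else:
            right = mid
    return left - 1


# Three sequences with query lengths 5, 16 and 1.
starts = [0, 5, 21, 22]
owner_of_block_2 = find_seq_idx_sketch(starts, 2, 3, 16, True)
owner_of_token_21 = find_seq_idx_sketch(starts, 21, 3, 16, False)
```

Under this layout a sequence can be assigned one more q-block than it strictly needs, which is why callers still check the local block index against the query length.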
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
init_softmax_M(sink_ptr, query_offset_1, query_mask_1, segm_idx_or_0, BLOCK_M, USE_SINKS, IS_3D) ¶
Initial row-max M for the online softmax.
Without sinks: -inf. With sinks: load the per-head sink bias once. In 3D mode only segment 0 loads — reduce_segments adds the sink contribution exactly once across segments, so other segments must start from -inf.
segm_idx_or_0 is the 3D segment index or 0 for 2D (caller passes 0 when IS_3D is False).
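The selection logic reads roughly as follows in plain Python. This is a sketch under assumptions: `init_row_max_sketch` is a hypothetical name, `row_heads` and `row_valid` stand in for `query_offset_1` and `query_mask_1`, and masked rows are assumed to start from -inf.

```python
NEG_INF = float("-inf")


def init_row_max_sketch(sinks, row_heads, row_valid, segm_idx, USE_SINKS, IS_3D):
    """Return the initial running max for each query row of a block.

    sinks: one sink bias per query head.
    row_heads: head index of each row; row_valid: whether the row is in bounds.
    """
    # No sinks, or a 3D segment other than the first: start from -inf.
    if not USE_SINKS or (IS_3D and segm_idx != 0):
        return [NEG_INF] * len(row_heads)
    return [sinks[h] if ok else NEG_INF for h, ok in zip(row_heads, row_valid)]
```

Starting every non-zero segment at -inf keeps the sink from being counted once per segment when the partial results are merged.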
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
load_qq_bias_tile(qq_bias_row_ptrs, seq_offset, context_len, qq_bias_stride_0) ¶
Load the qq-bias slice for keys that correspond to query rows.
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
resolve_seq_and_query_len(query_start_len_ptr, seq_lens_ptr, q_block_global_idx, num_seqs, BLOCK_Q) ¶
Resolve the (sequence, q-block-within-sequence) pair and load the per-sequence lengths.
Shared across every attention kernel — the q_block_global_idx program id indexes into the flattened (seq, q_block_in_seq) space, and a binary search over query_start_len_ptr recovers the (seq, local-q-block) pair.
Returns (seq_idx, q_block_local_idx, cur_batch_in_all_start_index, cur_batch_query_len, seq_len). Callers must still early-return when q_block_local_idx * BLOCK_Q >= cur_batch_query_len (Triton helpers cannot return from the caller).
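The lookup and the caller-side early return can be sketched with Python lists in place of the pointers. Here `bisect_right` stands in for the kernel's hand-written binary search, the q-block layout (`start // BLOCK_Q + i`) is an assumption, and `resolve_sketch` is a hypothetical name.

```python
from bisect import bisect_right


def resolve_sketch(query_start_len, seq_lens, q_block_global_idx, BLOCK_Q):
    """Map a global q-block index to per-sequence information.

    Returns (seq_idx, q_block_local_idx, cur_batch_in_all_start_index,
    cur_batch_query_len, seq_len).
    """
    num_seqs = len(seq_lens)
    # First global q-block index owned by each sequence (assumed layout).
    block_starts = [query_start_len[i] // BLOCK_Q + i for i in range(num_seqs)]
    seq_idx = bisect_right(block_starts, q_block_global_idx) - 1
    q_block_local_idx = q_block_global_idx - block_starts[seq_idx]
    start = query_start_len[seq_idx]
    cur_batch_query_len = query_start_len[seq_idx + 1] - start
    return seq_idx, q_block_local_idx, start, cur_batch_query_len, seq_lens[seq_idx]


# Three sequences with query lengths 5, 16 and 1; global q-block 2 is padding.
seq_idx, local_idx, start, q_len, seq_len = resolve_sketch([0, 5, 21, 22], [105, 16, 50], 2, 16)
skip_block = local_idx * 16 >= q_len  # the check each caller performs itself
```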
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
softmax_step(S, M, L) ¶
Online softmax update for one tile.
Returns (M_new, L_new, P, alpha). Caller is responsible for rescaling its accumulator(s) by alpha[:, None] — done outside so kernels with a different number / shape of accumulators can reuse the same step.
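To make the contract concrete, here is a list-based Python sketch of a standard online-softmax step together with the caller-side rescale. It follows the usual formulation and is not a copy of the Triton helper; `softmax_step_sketch` is a hypothetical name, and the guard for fully masked rows is an assumption.

```python
import math

NEG_INF = float("-inf")


def softmax_step_sketch(S, M, L):
    """One online-softmax update for a tile.

    S: list of score rows for this tile; M, L: running row max and exp-sum.
    Returns (M_new, L_new, P, alpha).
    """
    M_new, L_new, P, alpha = [], [], [], []
    for s_row, m_old, l_old in zip(S, M, L):
        m = max(m_old, max(s_row))
        # If every score so far is -inf, subtract 0.0 to avoid inf - inf.
        m_safe = m if m > NEG_INF else 0.0
        p_row = [math.exp(s - m_safe) for s in s_row]
        a = math.exp(m_old - m_safe)  # rescale factor for earlier tiles
        M_new.append(m)
        L_new.append(l_old * a + sum(p_row))
        P.append(p_row)
        alpha.append(a)
    return M_new, L_new, P, alpha


def attend_in_tiles(scores, values, tile):
    """Single-row attention over scalar values, processed tile by tile."""
    M, L, acc = [NEG_INF], [0.0], 0.0
    for i in range(0, len(scores), tile):
        M, L, P, alpha = softmax_step_sketch([scores[i:i + tile]], M, L)
        # The caller rescales its own accumulator by alpha, then adds P @ V.
        acc = acc * alpha[0] + sum(p * v for p, v in zip(P[0], values[i:i + tile]))
    return acc / L[0]
```

Keeping the rescale in the caller means the step itself never needs to know how many accumulators exist or what shape they have.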
Source code in vllm/v1/attention/ops/triton_attention_helpers.py
store_segm_reduce_scalars(segm_max_ptr, segm_expsum_ptr, query_offset_0, query_offset_1, segm_idx, M, L, query_mask_0, query_mask_1, num_query_heads, NUM_SEGMENTS_PER_SEQ) ¶
Store per-segment M and L for reduce_segments to combine into the final softmax.
Shared across every 3D attention epilogue; the per-token output stripes are mode-specific (flat / 2-stream split / 4-stream split) and stay inlined.
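How the stored scalars are later combined can be sketched for one query row in plain Python. This assumes the standard log-sum-exp merge and that each segment also contributes an unnormalized partial output; `combine_segments_sketch` is a hypothetical name and is not the reduce_segments kernel.

```python
import math


def combine_segments_sketch(seg_M, seg_L, seg_acc):
    """Merge per-segment (M, L, unnormalized output) into the final output.

    seg_M: per-segment row max; seg_L: per-segment exp-sum relative to seg_M;
    seg_acc: per-segment sum of exp(score - seg_M) * value (scalar here).
    """
    M = max(seg_M)
    if M == float("-inf"):
        return 0.0  # no segment saw any unmasked key
    # Bring every segment onto the common max before adding.
    scale = [math.exp(m - M) for m in seg_M]
    L = sum(l * s for l, s in zip(seg_L, scale))
    acc = sum(a * s for a, s in zip(seg_acc, scale))
    return acc / L
```

Segments that processed no keys carry M = -inf, so their scale factor is zero and they drop out of both sums.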