vllm_omni.diffusion.attention.backends.utils.piecewise_attn

Piecewise attention for mixed causal / full (bidirectional) masks.

Dispatches each segment as a separate attention call whose causal flag follows FlashAttention's bottom-right convention: Q[s:e] attends to K[:e], and the causal mask is anchored at the bottom-right corner, so the last query row lines up with the last key.

Per segment
  • causal segment [s, e): attn(Q[:, s:e], K[:, :e], V[:, :e], causal=True)
  • full-attn span [a, e): attn(Q[:, a:e], K[:, :e], V[:, :e], causal=False)
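
The per-segment rule above can be sketched in pure Python. This is only an illustrative reference under simplifying assumptions, not the module's implementation: `ref_attn` and `piecewise_sketch` are hypothetical names, there is a single head and no batch dimension, the query covers the whole sequence (offset 0), and segments are plain `(start, end, mode)` tuples in that assumed order.

```python
import math

def ref_attn(q, k, v, causal, scale):
    """Naive single-head attention on lists of vectors (hypothetical helper).

    With causal=True the mask is bottom-right aligned: the last query row
    lines up with the last key, so query i sees keys 0 .. i + (len(k) - len(q)).
    """
    lq, lk = len(q), len(k)
    out = []
    for i, qi in enumerate(q):
        limit = i + (lk - lq) if causal else lk - 1
        scores = [scale * sum(a * b for a, b in zip(qi, k[j]))
                  for j in range(limit + 1)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(v[0]))])
    return out

def piecewise_sketch(q, k, v, segments, scale):
    """Run one attention call per segment and stitch the outputs together."""
    out = [None] * len(q)
    for start, end, mode in segments:
        # Both modes see keys [0, end); only the causal flag differs.
        out[start:end] = ref_attn(q[start:end], k[:end], v[:end],
                                  mode == "causal", scale)
    return out
```

In this sketch, a query inside a full-attn span sees every key up to the end of its span, while a query inside a causal segment sees only keys at or before its own position.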

Segment

Bases: NamedTuple

end instance-attribute

end: int

mode instance-attribute

mode: Literal['causal', 'full']

start instance-attribute

start: int

build_segments

build_segments(full_attn_spans, query_offset, query_len)

full_attn_spans: list of (start, end) half-open spans in global coordinates
query_offset: starting position of the query in the global sequence
query_len: length of the query

return

List[Segment] in global coordinates, clipped to [query_offset, query_offset + query_len)
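
One plausible way to produce such a list is sketched below. It is a guess at the behavior described, not the library's code: `build_segments_sketch` is a hypothetical name, the `Segment` field order is assumed (the page only lists the field names), spans are assumed not to overlap, and every gap between full-attn spans is assumed to become a causal segment.

```python
from typing import List, Literal, NamedTuple, Sequence, Tuple

class Segment(NamedTuple):
    # Field order is an assumption; only the field names are documented.
    start: int
    end: int
    mode: Literal["causal", "full"]

def build_segments_sketch(
    full_attn_spans: Sequence[Tuple[int, int]],
    query_offset: int,
    query_len: int,
) -> List[Segment]:
    q_start, q_end = query_offset, query_offset + query_len
    segments: List[Segment] = []
    cursor = q_start  # first global position not yet covered by a segment
    for span_start, span_end in sorted(full_attn_spans):
        # Clip the span to the part of the query window still uncovered.
        s, e = max(span_start, cursor), min(span_end, q_end)
        if s >= e:
            continue  # span lies entirely outside the remaining window
        if cursor < s:
            segments.append(Segment(start=cursor, end=s, mode="causal"))
        segments.append(Segment(start=s, end=e, mode="full"))
        cursor = e
    if cursor < q_end:
        segments.append(Segment(start=cursor, end=q_end, mode="causal"))
    return segments
```

For example, under these assumptions a single span (2, 5) with query_offset=3 and query_len=3 is clipped to a full segment [3, 5) followed by a causal segment [5, 6).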

piecewise_attn

piecewise_attn(
    query,
    key,
    value,
    full_attn_spans: list[list[tuple[int, int]]],
    softmax_scale: float,
    attn_func,
)