class DFlashCudaGraphManager(CudaGraphManager):
    """CUDA graph manager for DFlash's parallel-drafting query forward.

    For each captured batch shape it builds the attention metadata itself,
    via ``_prepare_dflash_inputs_to_capture``, instead of relying on metadata
    supplied by the base manager.
    """

    def __init__(self, *args, causal: bool = False, **kwargs) -> None:
        """Initialize the base manager and record the attention causality flag.

        Args:
            *args: Passed through unchanged to the base class.
            causal: Stored on the instance and later forwarded as ``causal=``
                to ``_prepare_dflash_inputs_to_capture`` during capture.
            **kwargs: Passed through unchanged to the base class.
        """
        super().__init__(*args, **kwargs)
        self.causal = causal

    def capture(
        self,
        forward_fn: Callable,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs using a DFlash-specific forward factory.

        A factory is handed to the base class's ``capture``. For each batch
        descriptor, the factory prepares attention inputs and returns a
        callable that runs ``forward_fn`` for a given graph mode.

        Args:
            forward_fn: Invoked positionally as
                ``forward_fn(num_reqs, num_tokens, attn_metadata,
                slot_mappings, num_tokens_across_dp, cg_mode)``.
            input_buffers: Forwarded to ``_prepare_dflash_inputs_to_capture``.
            block_tables: Forwarded to ``_prepare_dflash_inputs_to_capture``.
            attn_groups: Forwarded to ``_prepare_dflash_inputs_to_capture``.
            kv_cache_config: Forwarded to ``_prepare_dflash_inputs_to_capture``.
            max_model_len: Forwarded to ``_prepare_dflash_inputs_to_capture``.
            progress_bar_desc: Label passed to the base class's ``capture``.
        """

        def create_forward_fn(
            desc: BatchExecutionDescriptor,
            warmup: bool,
        ) -> tuple[Callable[[CUDAGraphMode], None], AttentionState]:
            # ``warmup`` is part of the factory signature but is not
            # referenced in this implementation.
            n_tokens = desc.num_tokens
            # When the descriptor's request count is falsy (unset or 0), fall
            # back to the smaller of the token count and ``self.max_num_reqs``.
            n_reqs = desc.num_reqs or min(n_tokens, self.max_num_reqs)

            # With data parallelism, every rank is reported as having the same
            # token count; without it, no per-rank tensor is passed.
            if self.dp_size > 1:
                dp_token_counts = torch.full(
                    (self.dp_size,), n_tokens, dtype=torch.int32, device="cpu"
                )
            else:
                dp_token_counts = None

            # Attention preparation is skipped for PIECEWISE capture.
            is_piecewise = desc.cg_mode == CUDAGraphMode.PIECEWISE
            prepared = _prepare_dflash_inputs_to_capture(
                n_reqs,
                n_tokens,
                input_buffers,
                block_tables,
                attn_groups,
                kv_cache_config,
                max_model_len,
                skip_attn=is_piecewise,
                causal=self.causal,
            )
            metadata, slots = prepared

            def run(cg_mode):
                # Closure over the per-descriptor inputs prepared above.
                return forward_fn(
                    n_reqs,
                    n_tokens,
                    metadata,
                    slots,
                    dp_token_counts,
                    cg_mode,
                )

            return run, prepared

        super().capture(create_forward_fn, progress_bar_desc)