
vllm_omni.diffusion.models.dreamzero.state_dreamzero

DreamZero pipeline persistent state.

FRAMES_PER_CHUNK module-attribute

FRAMES_PER_CHUNK = 4

logger module-attribute

logger = getLogger(__name__)

DreamZeroState

Pipeline persistent state across forward() calls.

Lifecycle
  • Created once in DreamZeroPipeline.__init__()
  • Mutated on every forward() call (frames appended, KV cache grown)
  • reset() is called on a new session, a language change, or when local_attn_size is exceeded

accumulate_frames

accumulate_frames(stitched: ndarray) -> ndarray

Accumulate stitched frames and return multi-frame video.

Parameters:

  • stitched (ndarray, required): a single (H, W, C) frame or (T, H, W, C) frames, already stitched by the transform.

Returns:

  • ndarray: a (T, H, W, C) array. T=1 on the first call and T=FRAMES_PER_CHUNK (4) on every later call.
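The return contract above (one frame on the first call, FRAMES_PER_CHUNK frames afterwards) can be met with a rolling buffer. The sketch below is a minimal illustration under assumptions, not the DreamZeroState implementation: the `FrameAccumulator` class and its attributes are hypothetical, and padding a short history by repeating the oldest buffered frame is an assumed policy for the second and third calls.

```python
from collections import deque

import numpy as np

FRAMES_PER_CHUNK = 4


class FrameAccumulator:
    """Hypothetical stand-in for the frame buffer inside DreamZeroState."""

    def __init__(self, frames_per_chunk: int = FRAMES_PER_CHUNK) -> None:
        self.frames_per_chunk = frames_per_chunk
        # Rolling buffer: appending beyond maxlen silently drops the oldest frame.
        self.buffer: deque = deque(maxlen=frames_per_chunk)
        self.num_calls = 0

    def accumulate_frames(self, stitched: np.ndarray) -> np.ndarray:
        # Promote a single (H, W, C) frame to (1, H, W, C).
        frames = stitched[None] if stitched.ndim == 3 else stitched
        if frames.ndim != 4:
            raise ValueError(f"expected 3-D or 4-D input, got shape {stitched.shape}")
        self.buffer.extend(frames)  # iterates over axis 0, one frame at a time
        self.num_calls += 1
        if self.num_calls == 1:
            # First call: only the newest frame, so T=1.
            return frames[-1:]
        # Later calls: the newest frames_per_chunk frames, left-padded with the
        # oldest buffered frame while the history is still short (assumption).
        history = list(self.buffer)
        pad = [history[0]] * (self.frames_per_chunk - len(history))
        return np.stack(pad + history)
```

Using `deque(maxlen=...)` keeps memory bounded to one chunk of frames regardless of how many forward() calls a session makes.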

create_kv_caches

create_kv_caches(
    batch_size: int,
    dtype: dtype,
    device: device,
    num_layers: int,
    num_heads: int,
    head_dim: int,
) -> None

Initialize empty KV caches and cross-attention caches.
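A minimal sketch of what "empty" caches could look like, assuming one self-attention KV tensor per layer whose sequence axis starts at length 0, plus one cross-attention entry per layer, duplicated for the positive and negative (`is_negative`) branches. The tensor layout `(2, B, S, H, D)` and the dict keys `is_init`, `k` and `v` are hypothetical. NumPy stands in for torch so the sketch runs without a GPU, which is why the `device` argument is dropped.

```python
from typing import Union

import numpy as np

# Hypothetical per-layer cross-attention entry (mirrors dict[str, bool | Tensor | None]).
CrossAttnCache = dict[str, Union[bool, np.ndarray, None]]


def create_kv_caches(
    batch_size: int,
    dtype: np.dtype,
    num_layers: int,
    num_heads: int,
    head_dim: int,
) -> dict[str, list]:
    """Hypothetical sketch: allocate empty caches for both branches."""

    def empty_kv() -> list[np.ndarray]:
        # Axis 0 stacks K and V; axis 2 is the sequence axis, which starts at
        # length 0 and grows on every prefill (assumed layout).
        return [
            np.zeros((2, batch_size, 0, num_heads, head_dim), dtype=dtype)
            for _ in range(num_layers)
        ]

    def empty_crossattn() -> list[CrossAttnCache]:
        # Cross-attention K/V are not computed yet; "is_init" marks that (assumed keys).
        return [{"is_init": False, "k": None, "v": None} for _ in range(num_layers)]

    return {
        "kv": empty_kv(),
        "kv_neg": empty_kv(),
        "crossattn": empty_crossattn(),
        "crossattn_neg": empty_crossattn(),
    }
```

Building each branch with its own call to the helper keeps the two branches from sharing array or dict objects.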

get_crossattn_caches

get_crossattn_caches(
    is_negative: bool = False,
) -> list[dict[str, bool | Tensor | None]]

Get the cross-attention caches for the positive branch, or for the negative branch when is_negative is True.

get_kv_caches

get_kv_caches(is_negative: bool = False) -> list[Tensor]

Get the KV caches for the positive branch, or for the negative branch when is_negative is True.
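Both getters select a branch with the `is_negative` flag. A minimal sketch, assuming the state keeps one list per branch; the `BranchCaches` container and its field names are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class BranchCaches:
    """Hypothetical container: one list of per-layer caches per branch."""

    kv_pos: list = field(default_factory=list)
    kv_neg: list = field(default_factory=list)
    crossattn_pos: list = field(default_factory=list)
    crossattn_neg: list = field(default_factory=list)

    def get_kv_caches(self, is_negative: bool = False) -> list:
        # The stored list itself is returned, so in-place updates by the
        # caller are visible on the next call.
        return self.kv_neg if is_negative else self.kv_pos

    def get_crossattn_caches(self, is_negative: bool = False) -> list:
        return self.crossattn_neg if is_negative else self.crossattn_pos
```

Returning the list by reference rather than a copy is what lets a per-layer update persist across forward() calls.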

reset

reset() -> None

Clear all state.

should_reset

should_reset(
    text_tokens: Tensor | None,
    num_video_frames: int,
    local_attn_size: int,
) -> bool

Return whether the state should be reset before the current forward() call.
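The three reset triggers named under Lifecycle (new session, language change, local_attn_size exceeded) map naturally onto the three arguments of should_reset. The sketch below is an assumption-laden illustration: the `ResetPolicy` class and its attributes are hypothetical, it assumes `text_tokens=None` means "keep the current prompt", and it assumes `local_attn_size` counts frames.

```python
from typing import Optional

import numpy as np


class ResetPolicy:
    """Hypothetical sketch of the reset triggers listed under Lifecycle."""

    def __init__(self) -> None:
        self.text_tokens: Optional[np.ndarray] = None  # prompt of the current session
        self.num_cached_frames = 0  # frames already held in the KV cache

    def should_reset(
        self,
        text_tokens: Optional[np.ndarray],
        num_video_frames: int,
        local_attn_size: int,
    ) -> bool:
        # New session: no prompt has been cached yet.
        if self.text_tokens is None:
            return True
        # Language change: the incoming prompt differs from the cached one
        # (assumption: None means "reuse the current prompt").
        if text_tokens is not None and not np.array_equal(text_tokens, self.text_tokens):
            return True
        # Window overflow: the cache would outgrow the local attention window
        # (assumption: local_attn_size is measured in frames).
        return self.num_cached_frames + num_video_frames > local_attn_size
```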

update_kv_cache

update_kv_cache(
    layer_index: int,
    updated_kv: Tensor,
    is_negative: bool = False,
) -> None

Update a single layer's KV cache after prefill.
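After prefill, the attention layer yields a cache that already includes the new tokens, so updating one layer can be a validated swap. A minimal sketch under assumptions: it takes the branch's cache list directly (as a caller might obtain it from get_kv_caches(is_negative)), it assumes the `(2, B, S, H, D)` layout with sequence axis 2, and the shape checks are illustrative additions, not documented behavior. NumPy stands in for torch.

```python
import numpy as np


def update_kv_cache(
    kv_caches: list,
    layer_index: int,
    updated_kv: np.ndarray,
) -> None:
    """Hypothetical sketch: swap in the grown cache for one layer after prefill."""
    old = kv_caches[layer_index]
    # Assumed layout (2, B, S, H, D): every axis except the sequence axis must
    # match, and the sequence axis may only grow.
    if old.shape[:2] + old.shape[3:] != updated_kv.shape[:2] + updated_kv.shape[3:]:
        raise ValueError(f"shape mismatch: {old.shape} vs {updated_kv.shape}")
    if updated_kv.shape[2] < old.shape[2]:
        raise ValueError("updated cache is shorter than the existing one")
    kv_caches[layer_index] = updated_kv
```

Because the list is mutated in place, every later lookup through the same branch sees the grown cache.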