Skip to content

vllm_omni.diffusion.models.internvla_a1

InternVLA-A1 diffusion model components.

Modules:

| Name | Description |
| --- | --- |
| adapter_qwen3_vl | |
| config | |
| cosmos_ci_torch | |
| model_cosmos | |
| model_internvla_a1 | |
| pipeline_internvla_a1 | |

InternVLAA1

Bases: Module

action_in_proj instance-attribute

action_in_proj = Linear(max_action_dim, hidden_size)

action_out_proj instance-attribute

action_out_proj = Linear(hidden_size, max_action_dim)

action_time_mlp_in instance-attribute

action_time_mlp_in = Linear(2 * hidden_size, hidden_size)

action_time_mlp_out instance-attribute

action_time_mlp_out = Linear(hidden_size, hidden_size)

config instance-attribute

config = config

cosmos instance-attribute

cosmos = ImageTokenizer(
    checkpoint_enc=str(cosmos_encoder_path),
    checkpoint_dec=str(cosmos_decoder_path),
    device=device,
)

cosmos_in_proj instance-attribute

cosmos_in_proj = Conv2d(
    vae_dim, hidden_size, kernel_size=1, stride=1, padding=0
)

cosmos_out_layer_norm instance-attribute

cosmos_out_layer_norm = LayerNorm(hidden_size)

cosmos_out_proj instance-attribute

cosmos_out_proj = Linear(hidden_size, vae_dim)

downsample_conv instance-attribute

downsample_conv = Conv2d(
    hidden_size,
    hidden_size,
    kernel_size=ds,
    stride=ds,
    padding=0,
)

qwen3_vl_with_expert instance-attribute

qwen3_vl_with_expert = Qwen3VLWithExpertModel(
    vlm_config, action_expert_config, precision=dtype
)

state_proj instance-attribute

state_proj = Linear(max_state_dim, hidden_size)

upsample_conv instance-attribute

upsample_conv = ConvTranspose2d(
    hidden_size,
    hidden_size,
    kernel_size=ds,
    stride=ds,
    padding=0,
)

denoise_step

denoise_step(
    state: Tensor,
    prefix_pad_masks: Tensor,
    past_key_values: Any,
    max_prefix_position_ids: Tensor,
    x_t: Tensor,
    timestep: Tensor,
) -> Tensor

denoise_step_optimized

denoise_step_optimized(
    suffix_static: SuffixStaticContext,
    past_key_values: Any,
    x_t: Tensor,
    timestep: Tensor,
) -> Tensor

embed_middle

embed_middle(
    images: Tensor, img_masks: Tensor
) -> tuple[Tensor, Tensor, Tensor]

embed_prefix

embed_prefix(
    pixel_values: Tensor,
    image_grid_thw: Tensor,
    lang_tokens: Tensor,
    lang_masks: Tensor,
) -> tuple[Tensor, Tensor, Tensor]

embed_suffix

embed_suffix(
    state: Tensor, noisy_actions: Tensor, timestep: Tensor
) -> tuple[Tensor, Tensor, Tensor]

get_cosmos_features

get_cosmos_features(images: Tensor) -> Tensor

get_position_ids

get_position_ids(
    lang_tokens: Tensor,
    image_grid_thw: Tensor | None,
    pad_masks: Tensor,
) -> tuple[Tensor, Any]

prepare_suffix_static_context

prepare_suffix_static_context(
    state: Tensor,
    prefix_pad_masks: Tensor,
    max_prefix_position_ids: Tensor,
) -> SuffixStaticContext

sample_actions

sample_actions(
    images: Tensor,
    img_masks: Tensor,
    pixel_values: Tensor,
    image_grid_thw: Tensor,
    lang_tokens: Tensor,
    lang_masks: Tensor,
    state: Tensor,
    *,
    noise: Tensor | None = None,
    num_steps: int | None = None,
    decode_image: bool = False,
) -> tuple[Tensor, Tensor | None]

sample_noise

sample_noise(
    shape: tuple[int, ...], device: device
) -> Tensor

set_attention_implementation

set_attention_implementation(
    attn_implementation: str,
) -> None

InternVLAA1Config dataclass

Standalone-compatible InternVLA-A1 configuration; a few of its defaults are intended for fake/smoke-test runs.

action_expert_variant class-attribute instance-attribute

action_expert_variant: str = 'qwen3_28l'

attn_implementation class-attribute instance-attribute

attn_implementation: str = 'eager'

chunk_size class-attribute instance-attribute

chunk_size: int = 50

compile_mode class-attribute instance-attribute

compile_mode: str = 'max-autotune'

compile_model class-attribute instance-attribute

compile_model: bool = False

device class-attribute instance-attribute

device: str = 'cuda'

dtype class-attribute instance-attribute

dtype: str = 'bfloat16'

empty_cameras class-attribute instance-attribute

empty_cameras: int = 0

enable_regional_compile class-attribute instance-attribute

enable_regional_compile: bool = False

enable_suffix_static_context_optimization class-attribute instance-attribute

enable_suffix_static_context_optimization: bool = False

freeze_vision_encoder class-attribute instance-attribute

freeze_vision_encoder: bool = False

gradient_checkpointing class-attribute instance-attribute

gradient_checkpointing: bool = False

hidden_size class-attribute instance-attribute

hidden_size: int = 128

image_history class-attribute instance-attribute

image_history: int = 2

image_resolution class-attribute instance-attribute

image_resolution: tuple[int, int] = (224, 224)

input_features class-attribute instance-attribute

input_features: dict[str, Any] = field(default_factory=dict)

intermediate_size class-attribute instance-attribute

intermediate_size: int = 256

lambda_gen class-attribute instance-attribute

lambda_gen: float = 0.01

max_action_dim class-attribute instance-attribute

max_action_dim: int = 32

max_period class-attribute instance-attribute

max_period: float = 4.0

max_state_dim class-attribute instance-attribute

max_state_dim: int = 32

min_period class-attribute instance-attribute

min_period: float = 0.004

n_action_steps class-attribute instance-attribute

n_action_steps: int = 50

num_attention_heads class-attribute instance-attribute

num_attention_heads: int = 4

num_cameras class-attribute instance-attribute

num_cameras: int = 3

num_hidden_layers class-attribute instance-attribute

num_hidden_layers: int = 2

num_inference_steps class-attribute instance-attribute

num_inference_steps: int = 10

output_features class-attribute instance-attribute

output_features: dict[str, Any] = field(
    default_factory=dict
)

pixel_feature_dim class-attribute instance-attribute

pixel_feature_dim: int = 48

qwen3_vl_variant class-attribute instance-attribute

qwen3_vl_variant: str = 'qwen3_vl_28l'

regional_compile_dynamic class-attribute instance-attribute

regional_compile_dynamic: bool = True

scale_factor class-attribute instance-attribute

scale_factor: int = 8

time_sampling_beta_alpha class-attribute instance-attribute

time_sampling_beta_alpha: float = 1.5

time_sampling_beta_beta class-attribute instance-attribute

time_sampling_beta_beta: float = 1.0

time_sampling_offset class-attribute instance-attribute

time_sampling_offset: float = 0.001

time_sampling_scale class-attribute instance-attribute

time_sampling_scale: float = 0.999

tokenizer_max_length class-attribute instance-attribute

tokenizer_max_length: int = 48

train_expert_only class-attribute instance-attribute

train_expert_only: bool = False

train_vlm_only class-attribute instance-attribute

train_vlm_only: bool = False

type class-attribute instance-attribute

type: str = 'internvla_a1'

vocab_size class-attribute instance-attribute

vocab_size: int = 256

from_model_config classmethod

from_model_config(
    model_config: dict[str, Any] | None,
) -> InternVLAA1Config

from_pretrained classmethod

from_pretrained(
    checkpoint_dir: str | Path,
) -> InternVLAA1Config

InternVLAA1Pipeline

Bases: Module, DiffusionPipelineProfilerMixin

InternVLA-A1 pipeline wrapper for the policy implementation.

config instance-attribute

config = _build_config(od_config)

enable_warmup instance-attribute

enable_warmup = (
    bool(enable_warmup)
    if isinstance(enable_warmup, bool)
    else False
)

model_dir instance-attribute

model_dir = model

od_config instance-attribute

od_config = od_config

policy instance-attribute

policy = _initialize_policy()

prefix instance-attribute

prefix = prefix

processor_model_name instance-attribute

processor_model_name = str(
    get("processor_model_name", DEFAULT_QWEN3_VL_MODEL)
)

strict_load instance-attribute

strict_load = bool(get('strict_load', False))

forward

has_real_checkpoint

has_real_checkpoint() -> bool

runtime_mode

runtime_mode() -> str

InternVLAA1Policy

Bases: Module

config instance-attribute

config = config

input_builder instance-attribute

input_builder = Qwen3VLInputBuilder(
    processor_model_name=processor_model_name,
    max_length=tokenizer_max_length,
)

model instance-attribute

model = InternVLAA1(
    config,
    cosmos_encoder_path=cosmos_encoder_path,
    cosmos_decoder_path=cosmos_decoder_path,
)

forward

forward(
    batch: dict[str, Any],
    *,
    noise: Tensor | None = None,
    decode_image: bool = False,
) -> tuple[Tensor, Tensor | None]

from_pretrained classmethod

from_pretrained(
    checkpoint_dir: str | Path,
    *,
    config: InternVLAA1Config | None = None,
    processor_model_name: str = DEFAULT_QWEN3_VL_MODEL,
    strict: bool = False,
) -> InternVLAA1Policy

prepare_state

prepare_state(batch: dict[str, Tensor]) -> Tensor

to

to(*args, **kwargs)

InternVLAA1TrainMetadata dataclass

action_mode class-attribute instance-attribute

action_mode: str = 'delta'

processor_model_name class-attribute instance-attribute

processor_model_name: str = DEFAULT_QWEN3_VL_MODEL

from_pretrained classmethod

from_pretrained(
    checkpoint_dir: str | Path,
) -> InternVLAA1TrainMetadata

get_internvla_a1_post_process_func

get_internvla_a1_post_process_func(
    od_config: OmniDiffusionConfig,
)