
vllm_omni.diffusion.models.gr00t.policy

Gr00tPolicy

Core policy class for Gr00t model inference.

This policy handles the end-to-end inference pipeline:
  1. Validates input observations
  2. Processes observations with the pretrained VLA processor
  3. Runs model inference
  4. Decodes and returns actions

The policy expects observations with specific modalities (video, state, language) and returns actions in the format defined by the model's modality configuration.
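The order of the four stages can be illustrated with a toy sketch. Everything here is a stand-in: `run_pipeline` and the `toy_*` functions are hypothetical names invented for illustration and are not part of `Gr00tPolicy`, whose internals this page does not show.

```python
from typing import Any, Callable


def run_pipeline(
    observation: dict[str, Any],
    validate: Callable[[dict[str, Any]], None],
    preprocess: Callable[[dict[str, Any]], Any],
    infer: Callable[[Any], Any],
    decode: Callable[[Any], dict[str, Any]],
) -> dict[str, Any]:
    """Hypothetical helper: chain the four stages in the documented order."""
    validate(observation)                   # 1. validate input observations
    model_inputs = preprocess(observation)  # 2. run the processor
    raw_output = infer(model_inputs)        # 3. run model inference
    return decode(raw_output)               # 4. decode into an action dict


# Toy stand-ins that only record the order in which they are called.
calls: list[str] = []


def toy_validate(obs: dict[str, Any]) -> None:
    calls.append("validate")
    assert set(obs) >= {"video", "state", "language"}, "missing modality"


def toy_preprocess(obs: dict[str, Any]) -> dict[str, Any]:
    calls.append("preprocess")
    return obs


def toy_infer(inputs: Any) -> list[list[float]]:
    calls.append("infer")
    return [[0.0, 0.0]]


def toy_decode(raw: Any) -> dict[str, Any]:
    calls.append("decode")
    return {"toy_action": raw}


actions = run_pipeline(
    {"video": {}, "state": {}, "language": {}},
    toy_validate,
    toy_preprocess,
    toy_infer,
    toy_decode,
)
```

Validation runs first in this sketch, so a malformed observation fails before any processing or inference work is done.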

collate_fn instance-attribute

collate_fn = self.processor.collator

embodiment_tag instance-attribute

embodiment_tag = embodiment_tag

language_key instance-attribute

language_key = language_keys[0]

modality_configs instance-attribute

modality_configs = {
    k: v
    for k, v in (
        all_modality_configs[
            self.embodiment_tag.value
        ].items()
    )
    if k != "rl_info"
}
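The effect of this comprehension can be seen on a small stand-in mapping. The contents of `all_modality_configs` below, including the embodiment name and the modality keys other than `"rl_info"`, are assumptions made up for illustration; the real configuration values are not shown on this page.

```python
# Hypothetical stand-in for the per-embodiment modality configuration.
all_modality_configs = {
    "toy_embodiment": {
        "video": ["cam_front"],
        "state": ["joint_position"],
        "action": ["joint_position"],
        "language": ["task_description"],
        "rl_info": ["reward"],
    }
}
embodiment_value = "toy_embodiment"  # plays the role of self.embodiment_tag.value

# Same filtering as above: keep every modality except "rl_info".
modality_configs = {
    k: v
    for k, v in all_modality_configs[embodiment_value].items()
    if k != "rl_info"
}
```

The filter builds a new dict, so the source mapping keeps its `"rl_info"` entry.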

model instance-attribute

model: Gr00tN1d7 = AutoModel.from_pretrained(
    model_dir, torch_dtype=torch.bfloat16
)

processor instance-attribute

processor: Gr00tN1d7Processor = (
    AutoProcessor.from_pretrained(processor_dir)
)

strict instance-attribute

strict = strict

check_action

check_action(action: dict[str, Any]) -> None

Validate that the action has the correct structure and types.

This method ensures that all required action keys are present and that their data types, shapes, and dimensions match the model's action space.

Expected action structure
  • action: dict[str, np.ndarray[np.float32, (B, T, D)]]
    • B: batch size
    • T: action horizon (number of future action steps)
    • D: action dimension (e.g., joint positions, velocities, gripper state)
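A minimal sketch of checks in this spirit, assuming NumPy is available. `check_action_sketch`, its `horizon` and `dims` parameters, and the `"arm"` and `"gripper"` keys are hypothetical; the checks that `check_action` actually performs may differ.

```python
import numpy as np


def check_action_sketch(
    action: dict, horizon: int, dims: dict[str, int]
) -> None:
    """Illustrative validator for the (B, T, D) float32 layout described above."""
    for key, dim in dims.items():
        assert key in action, f"missing action key: {key}"
        arr = action[key]
        assert isinstance(arr, np.ndarray), f"{key}: expected np.ndarray"
        assert arr.dtype == np.float32, f"{key}: expected float32, got {arr.dtype}"
        assert arr.ndim == 3, f"{key}: expected (B, T, D), got shape {arr.shape}"
        assert arr.shape[1] == horizon, f"{key}: expected T={horizon}"
        assert arr.shape[2] == dim, f"{key}: expected D={dim}"


# Batch of 2, action horizon of 16, with made-up action dimensions.
action = {
    "arm": np.zeros((2, 16, 7), dtype=np.float32),
    "gripper": np.zeros((2, 16, 1), dtype=np.float32),
}
check_action_sketch(action, horizon=16, dims={"arm": 7, "gripper": 1})
```

An array created without an explicit `dtype`, such as `np.zeros((2, 16, 7))`, defaults to `float64` and would fail the dtype check in this sketch.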

Parameters:

Name    Type            Description                                              Default
action  dict[str, Any]  Dictionary containing action arrays for each action key  required

Raises:

Type            Description
AssertionError  If any validation check fails

check_observation

check_observation(observation: dict[str, Any]) -> None

Validate that the observation has the correct structure and types.

This method ensures that all required modalities are present and that their data types, shapes, and dimensions match the model's expectations.

Expected observation structure
  • video: dict[str, np.ndarray[np.uint8, (B, T, H, W, C)]]
    • B: batch size
    • T: temporal horizon (number of frames)
    • H, W: image height and width
    • C: number of channels (must be 3 for RGB)
  • state: dict[str, np.ndarray[np.float32, (B, T, D)]]
    • B: batch size
    • T: temporal horizon (number of state observations)
    • D: state dimension
  • language: dict[str, list[list[str]]]
    • Shape: (B, T) where each element is a string
    • T: temporal horizon (typically 1 for language)
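An observation with this layout might be built as follows. This is a sketch, assuming NumPy: the keys `"cam_front"`, `"joint_position"`, and `"task_description"`, the sizes, and the helper `check_observation_sketch` are all made up for illustration, since the real keys come from the model's modality configuration.

```python
import numpy as np

B, T, H, W = 1, 2, 64, 64  # arbitrary sizes for illustration

observation = {
    # uint8 RGB frames, shape (B, T, H, W, C) with C == 3
    "video": {"cam_front": np.zeros((B, T, H, W, 3), dtype=np.uint8)},
    # float32 state vectors, shape (B, T, D)
    "state": {"joint_position": np.zeros((B, T, 7), dtype=np.float32)},
    # nested lists of strings, shape (B, T) with T == 1 here
    "language": {"task_description": [["pick up the cube"] for _ in range(B)]},
}


def check_observation_sketch(obs: dict) -> None:
    """Illustrative structural checks; not the library's implementation."""
    for modality in ("video", "state", "language"):
        assert modality in obs, f"missing modality: {modality}"
    for key, frames in obs["video"].items():
        assert isinstance(frames, np.ndarray), f"{key}: expected np.ndarray"
        assert frames.dtype == np.uint8, f"{key}: expected uint8"
        assert frames.ndim == 5 and frames.shape[-1] == 3, f"{key}: expected (B, T, H, W, 3)"
    for key, state in obs["state"].items():
        assert isinstance(state, np.ndarray), f"{key}: expected np.ndarray"
        assert state.dtype == np.float32, f"{key}: expected float32"
        assert state.ndim == 3, f"{key}: expected (B, T, D)"
    for key, rows in obs["language"].items():
        assert isinstance(rows, list), f"{key}: expected list[list[str]]"
        for row in rows:
            assert isinstance(row, list), f"{key}: expected list[list[str]]"
            assert all(isinstance(s, str) for s in row), f"{key}: expected strings"


check_observation_sketch(observation)
```

Note that language is a plain nested list rather than an array, so its (B, T) shape is implied by the list lengths.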

Parameters:

Name         Type            Description                                                  Default
observation  dict[str, Any]  Dictionary containing video, state, and language modalities  required

Raises:

Type            Description
AssertionError  If any validation check fails

get_action

get_action(
    observation: dict[str, Any],
    options: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]

reset

reset(
    options: dict[str, Any] | None = None,
) -> dict[str, Any]

Reset the policy to its initial state.

Parameters:

Name     Type                   Description                                      Default
options  dict[str, Any] | None  Dictionary containing the options for the reset  None

Returns:

Type            Description
dict[str, Any]  Dictionary containing the info after resetting the policy