vllm_omni.diffusion.models.gr00t.policy
Gr00tPolicy
Core policy class for Gr00t model inference.
This policy handles the end-to-end inference pipeline:

1. Validates input observations.
2. Processes observations with the pretrained VLA processor.
3. Runs model inference.
4. Decodes and returns actions.
The policy expects observations with specific modalities (video, state, language) and returns actions in the format defined by the model's modality configuration.
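The four pipeline stages can be pictured as a straight chain of calls. The sketch below is not the `Gr00tPolicy` implementation: `run_policy_pipeline` and its four callables are hypothetical stand-ins for `check_observation`, the processor, the model, and an action-decoding step whose exact API this page does not document.

```python
from typing import Any, Callable


def run_policy_pipeline(
    observation: dict[str, Any],
    validate: Callable[[dict[str, Any]], None],
    preprocess: Callable[[dict[str, Any]], Any],
    infer: Callable[[Any], Any],
    decode: Callable[[Any], dict[str, Any]],
) -> dict[str, Any]:
    """Run the four documented stages in order and return decoded actions.

    All four callables are hypothetical stand-ins, not the real Gr00t APIs.
    """
    validate(observation)                   # 1. validate input observations
    model_inputs = preprocess(observation)  # 2. process with the VLA processor
    raw_outputs = infer(model_inputs)       # 3. run model inference
    return decode(raw_outputs)              # 4. decode into an action dict
```

Putting validation first means a malformed observation fails before any preprocessing or model work is done.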
modality_configs instance-attribute
```python
modality_configs = {
    k: v
    for k, v in all_modality_configs[self.embodiment_tag.value].items()
    if k != "rl_info"
}
```
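The attribute above selects the configuration block for the policy's embodiment and drops the `"rl_info"` entry. A minimal standalone sketch of the same comprehension follows; the `EmbodimentTag` enum, the `"example_robot"` value, and the contents of `all_configs` are hypothetical placeholders, not the real Gr00t configuration.

```python
from enum import Enum
from typing import Any


class EmbodimentTag(Enum):
    # Hypothetical tag; the real enum and its values are not shown on this page.
    EXAMPLE_ROBOT = "example_robot"


def select_modality_configs(
    all_modality_configs: dict[str, dict[str, Any]],
    embodiment_tag: EmbodimentTag,
) -> dict[str, Any]:
    """Return the configs for one embodiment, dropping the "rl_info" entry."""
    return {
        k: v
        for k, v in all_modality_configs[embodiment_tag.value].items()
        if k != "rl_info"
    }


# Hypothetical per-embodiment configs, for illustration only.
all_configs = {
    "example_robot": {
        "video": {"keys": ["front_cam"]},
        "state": {"keys": ["joints"]},
        "action": {"keys": ["joints"]},
        "language": {"keys": ["instruction"]},
        "rl_info": {"reward": True},
    }
}
configs = select_modality_configs(all_configs, EmbodimentTag.EXAMPLE_ROBOT)
print(sorted(configs))  # ['action', 'language', 'state', 'video']
```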
model instance-attribute
```python
model: Gr00tN1d7 = AutoModel.from_pretrained(
    model_dir, torch_dtype=torch.bfloat16
)
```
processor instance-attribute
```python
processor: Gr00tN1d7Processor = AutoProcessor.from_pretrained(processor_dir)
```
check_action
Validate that the action has the correct structure and types.
This method ensures that all required action keys are present and that their data types, shapes, and dimensions match the model's action space.
Expected action structure:

- action: dict[str, np.ndarray[np.float32, (B, T, D)]]
  - B: batch size
  - T: action horizon (number of future action steps)
  - D: action dimension (e.g., joint positions, velocities, gripper state)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `action` | `dict[str, Any]` | Dictionary containing action arrays for each action key | required |
Raises:

| Type | Description |
|---|---|
| `AssertionError` | If any validation check fails |
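The checks described for `check_action` can be approximated as below. This is a hedged sketch, not the library's code: `check_action_sketch`, its `action_dims` and `horizon` parameters, and the key names in the usage are hypothetical, and the requirement that all keys share one batch size is an added assumption.

```python
from typing import Any

import numpy as np


def check_action_sketch(
    action: dict[str, Any], action_dims: dict[str, int], horizon: int
) -> None:
    """Assert each expected key maps to a float32 array of shape (B, horizon, D).

    action_dims (key -> D) and horizon are hypothetical inputs that would
    normally come from the model's modality configuration.
    """
    batch_size = None
    for key, dim in action_dims.items():
        assert key in action, f"missing action key: {key}"
        arr = action[key]
        assert isinstance(arr, np.ndarray), f"{key}: expected np.ndarray"
        assert arr.dtype == np.float32, f"{key}: expected float32, got {arr.dtype}"
        assert arr.ndim == 3, f"{key}: expected (B, T, D), got {arr.shape}"
        assert arr.shape[1] == horizon, f"{key}: expected T={horizon}"
        assert arr.shape[2] == dim, f"{key}: expected D={dim}"
        # Assumption: all action keys share one batch size B.
        if batch_size is None:
            batch_size = arr.shape[0]
        assert arr.shape[0] == batch_size, f"{key}: inconsistent batch size"


# Hypothetical keys: a 7-DoF arm and a 1-D gripper, 16-step horizon, batch of 2.
dims = {"arm": 7, "gripper": 1}
action = {
    "arm": np.zeros((2, 16, 7), dtype=np.float32),
    "gripper": np.zeros((2, 16, 1), dtype=np.float32),
}
check_action_sketch(action, dims, horizon=16)  # passes silently
```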
check_observation
Validate that the observation has the correct structure and types.
This method ensures that all required modalities are present and that their data types, shapes, and dimensions match the model's expectations.
Expected observation structure:

- video: dict[str, np.ndarray[np.uint8, (B, T, H, W, C)]]
  - B: batch size
  - T: temporal horizon (number of frames)
  - H, W: image height and width
  - C: number of channels (must be 3 for RGB)
- state: dict[str, np.ndarray[np.float32, (B, T, D)]]
  - B: batch size
  - T: temporal horizon (number of state observations)
  - D: state dimension
- language: dict[str, list[list[str]]]
  - Shape: (B, T), where each element is a string
  - T: temporal horizon (typically 1 for language)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `observation` | `dict[str, Any]` | Dictionary containing video, state, and language modalities | required |
Raises:

| Type | Description |
|---|---|
| `AssertionError` | If any validation check fails |
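A rough sketch of the structural checks listed above, assuming only the dtypes, ranks, and RGB channel count stated in the expected structure. `check_observation_sketch` and the sub-keys `front_cam`, `joints`, and `instruction` are hypothetical; the real method also compares against the model's modality configuration, which this sketch does not attempt.

```python
from typing import Any

import numpy as np


def check_observation_sketch(observation: dict[str, Any]) -> None:
    """Assert the video/state/language layout documented for check_observation."""
    for modality in ("video", "state", "language"):
        assert modality in observation, f"missing modality: {modality}"

    for key, frames in observation["video"].items():
        assert isinstance(frames, np.ndarray), f"video.{key}: expected np.ndarray"
        assert frames.dtype == np.uint8, f"video.{key}: expected uint8"
        assert frames.ndim == 5, f"video.{key}: expected (B, T, H, W, C)"
        assert frames.shape[-1] == 3, f"video.{key}: expected 3 RGB channels"

    for key, state in observation["state"].items():
        assert isinstance(state, np.ndarray), f"state.{key}: expected np.ndarray"
        assert state.dtype == np.float32, f"state.{key}: expected float32"
        assert state.ndim == 3, f"state.{key}: expected (B, T, D)"

    for key, texts in observation["language"].items():
        # Language is a (B, T) nested list of strings.
        assert isinstance(texts, list), f"language.{key}: expected list of lists"
        for row in texts:
            assert isinstance(row, list), f"language.{key}: expected list of lists"
            assert all(isinstance(s, str) for s in row), f"language.{key}: non-str"


# Hypothetical single-sample observation (B=1, T=1) with small 8x8 frames.
obs = {
    "video": {"front_cam": np.zeros((1, 1, 8, 8, 3), dtype=np.uint8)},
    "state": {"joints": np.zeros((1, 1, 7), dtype=np.float32)},
    "language": {"instruction": [["pick up the cube"]]},
}
check_observation_sketch(obs)  # passes silently
```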
get_action
```python
get_action(
    observation: dict[str, Any],
    options: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]
```
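This page does not describe the two returned dictionaries. Assuming the first is an action dictionary in the (B, T, D) layout described under `check_action`, a caller might step through the predicted horizon one action at a time, as sketched below. `iter_action_steps` is a hypothetical helper, and the meaning of the second dictionary and of `options` is left open.

```python
from typing import Iterator

import numpy as np


def iter_action_steps(
    action: dict[str, np.ndarray], batch_index: int = 0
) -> Iterator[dict[str, np.ndarray]]:
    """Yield one {key: (D,) array} dict per step of a (B, T, D) action chunk."""
    # Assumes every key shares the same horizon T.
    horizon = next(iter(action.values())).shape[1]
    for t in range(horizon):
        yield {key: arr[batch_index, t] for key, arr in action.items()}


# Hypothetical: action, info = policy.get_action(observation)
action = {
    "arm": np.arange(12, dtype=np.float32).reshape(2, 3, 2),
    "gripper": np.zeros((2, 3, 1), dtype=np.float32),
}
for step in iter_action_steps(action):
    pass  # send `step` to the robot controller here
```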