hub / github.com/NVIDIA/Isaac-GR00T / action_head_tensorrt_forward

Function action_head_tensorrt_forward

scripts/deployment/trt_model_forward.py:373–500

Replace ActionHead.get_action() with TRT-accelerated inference. VLLN (LayerNorm) stays in PyTorch. State Encoder, Action Encoder, DiT, and Action Decoder are replaced with TRT engines. N1.7 change: state is reshaped from [B, state_history_length, max_state_dim] to [B, 1, state_history_length * max_state_dim] before the state encoder.
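The N1.7 reshape can be illustrated without a tensor library by computing only the target shape. This is a minimal sketch that assumes the branch logic shown in the function body; `flattened_state_shape` is a hypothetical helper and is not part of the repository.

```python
from math import prod


def flattened_state_shape(shape):
    """Return the shape the state encoder would receive (hypothetical helper).

    Mirrors the branch logic of the function: a 3-D state with a history
    longer than one step is flattened to [B, 1, T * D]; any other shape is
    passed through unchanged.
    """
    shape = tuple(shape)
    if len(shape) == 3 and shape[1] > 1:
        batch = shape[0]
        return (batch, 1, prod(shape[1:]))
    return shape


print(flattened_state_shape((2, 4, 64)))  # (2, 1, 256)
print(flattened_state_shape((2, 1, 64)))  # (2, 1, 64)
```

With a history of 4 steps and a state width of 64, the encoder input has a last dimension of 256, so an engine built for a different history length would likely reject the input shape.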

(self, backbone_output, action_input, options=None)
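The signature takes `self` explicitly because the function is meant to be attached to an existing ActionHead instance. One common way to do that in Python is `types.MethodType`; whether the repository binds it this way is an assumption, and `DummyActionHead` and `fast_forward` below are hypothetical stand-ins.

```python
import types


class DummyActionHead:
    """Hypothetical stand-in for the real ActionHead class."""

    def get_action(self, backbone_output, action_input, options=None):
        return "pytorch"


def fast_forward(self, backbone_output, action_input, options=None):
    """Hypothetical replacement with the same signature as get_action."""
    return "tensorrt"


head = DummyActionHead()
other = DummyActionHead()

# Bind the free function to this one instance; `self` will be `head`.
head.get_action = types.MethodType(fast_forward, head)

print(head.get_action(None, None))   # tensorrt
print(other.get_action(None, None))  # pytorch
```

Patching the instance rather than the class leaves other instances on the original PyTorch path, which can be useful when comparing outputs side by side.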


def action_head_tensorrt_forward(self, backbone_output, action_input, options=None):
    """Replace ActionHead.get_action() with TRT-accelerated inference.
    VLLN (LayerNorm) stays in PyTorch. State Encoder, Action Encoder,
    DiT, and Action Decoder are replaced with TRT engines.

    N1.7 change: state is reshaped from [B, state_history_length, max_state_dim]
    to [B, 1, state_history_length * max_state_dim] before the state encoder.

    Args:
        self: ActionHead instance (monkey-patched)
        backbone_output: BatchFeature with backbone_features, backbone_attention_mask, image_mask
        action_input: BatchFeature with state, embodiment_id
    """
    # --- VLLN (PyTorch) + vl_self_attention (TRT if available, else PyTorch) ---
    backbone_features = backbone_output.backbone_features
    backbone_features = self.vlln(backbone_features)
    if hasattr(self, "vl_sa_engine") and self.vl_sa_engine is not None:
        engine_dtype = torch.bfloat16
        if backbone_features.dtype != engine_dtype:
            backbone_features = backbone_features.to(engine_dtype)
        self.vl_sa_engine.set_runtime_tensor_shape("hidden_states", backbone_features.shape)
        backbone_features = self.vl_sa_engine(backbone_features)["output"]
    else:
        backbone_features = self.vl_self_attention(backbone_features)
    vl_embs = backbone_features

    embodiment_id = action_input.embodiment_id
    batch_size = vl_embs.shape[0]
    device = vl_embs.device

    engine_dtype = torch.bfloat16

    # Ensure consistent dtypes
    if vl_embs.dtype != engine_dtype:
        vl_embs = vl_embs.to(engine_dtype)
    if action_input.state.dtype != engine_dtype:
        action_input.state = action_input.state.to(engine_dtype)
    if embodiment_id.dtype != torch.int64:
        embodiment_id = embodiment_id.to(torch.int64)

    # --- State history reshape (N1.7) ---
    # N1.7: state comes as [B, state_history_length, max_state_dim]
    # Flatten to [B, 1, state_history_length * max_state_dim] for the encoder
    state = action_input.state
    if state.ndim == 3 and state.shape[1] > 1:
        state = state.view(state.shape[0], 1, -1)
    elif state.ndim == 3 and state.shape[1] == 1:
        # Already [B, 1, dim] (state_history_length=1)
        pass
    else:
        # Unexpected shape, pass through
        logger.warning(f"Unexpected state shape: {state.shape}")

    # --- State Encoder TRT ---
    self.state_encoder_engine.set_runtime_tensor_shape("state", state.shape)
    self.state_encoder_engine.set_runtime_tensor_shape("embodiment_id", embodiment_id.shape)
    state_features = self.state_encoder_engine(state, embodiment_id)["output"]
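The "TRT if available, else PyTorch" branch for vl_self_attention is an optional-accelerator pattern. The sketch below shows the same dispatch with plain Python objects; `DummyEngine`, `DummyHead`, and `run_self_attention` are hypothetical names, and the doubling operation merely stands in for the real computation.

```python
class DummyEngine:
    """Hypothetical stand-in for an engine wrapper; returns a dict of outputs."""

    def __call__(self, x):
        return {"output": [v * 2 for v in x]}


class DummyHead:
    """Hypothetical module with an optional accelerated path."""

    def __init__(self, engine=None):
        # The attribute is only set when an engine was provided.
        if engine is not None:
            self.vl_sa_engine = engine

    def vl_self_attention(self, x):
        # Reference (non-accelerated) path.
        return [v * 2 for v in x]

    def run_self_attention(self, x):
        # getattr with a default covers both "attribute missing" and "set to None".
        engine = getattr(self, "vl_sa_engine", None)
        if engine is not None:
            return engine(x)["output"]
        return self.vl_self_attention(x)


print(DummyHead().run_self_attention([1, 2]))               # [2, 4]
print(DummyHead(DummyEngine()).run_self_attention([1, 2]))  # [2, 4]
```

Both paths should agree on the same input, which makes the fallback a convenient reference when validating an engine.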

Callers

nothing calls this directly

Calls 2

toMethod · 0.45

Tested by

no test coverage detected