Replace ActionHead.get_action() with TRT-accelerated inference. VLLN (LayerNorm) stays in PyTorch. State Encoder, Action Encoder, DiT, and Action Decoder are replaced with TRT engines. N1.7 change: state is reshaped from [B, state_history_length, max_state_dim] to [B, 1, state_history_length * max_state_dim] before the state encoder.
(self, backbone_output, action_input, options=None)
| 371 | |
| 372 | |
| 373 | def action_head_tensorrt_forward(self, backbone_output, action_input, options=None): |
| 374 | """Replace ActionHead.get_action() with TRT-accelerated inference. |
| 375 | VLLN (LayerNorm) stays in PyTorch. State Encoder, Action Encoder, |
| 376 | DiT, and Action Decoder are replaced with TRT engines. |
| 377 | |
| 378 | N1.7 change: state is reshaped from [B, state_history_length, max_state_dim] |
| 379 | to [B, 1, state_history_length * max_state_dim] before the state encoder. |
| 380 | |
| 381 | Args: |
| 382 | self: ActionHead instance (monkey-patched) |
| 383 | backbone_output: BatchFeature with backbone_features, backbone_attention_mask, image_mask |
| 384 | action_input: BatchFeature with state, embodiment_id |
| 385 | """ |
| 386 | # --- VLLN (PyTorch) + vl_self_attention (TRT if available, else PyTorch) --- |
| 387 | backbone_features = backbone_output.backbone_features |
| 388 | backbone_features = self.vlln(backbone_features) |
| 389 | if hasattr(self, "vl_sa_engine") and self.vl_sa_engine is not None: |
| 390 | engine_dtype = torch.bfloat16 |
| 391 | if backbone_features.dtype != engine_dtype: |
| 392 | backbone_features = backbone_features.to(engine_dtype) |
| 393 | self.vl_sa_engine.set_runtime_tensor_shape("hidden_states", backbone_features.shape) |
| 394 | backbone_features = self.vl_sa_engine(backbone_features)["output"] |
| 395 | else: |
| 396 | backbone_features = self.vl_self_attention(backbone_features) |
| 397 | vl_embs = backbone_features |
| 398 | |
| 399 | embodiment_id = action_input.embodiment_id |
| 400 | batch_size = vl_embs.shape[0] |
| 401 | device = vl_embs.device |
| 402 | |
| 403 | engine_dtype = torch.bfloat16 |
| 404 | |
| 405 | # Ensure consistent dtypes |
| 406 | if vl_embs.dtype != engine_dtype: |
| 407 | vl_embs = vl_embs.to(engine_dtype) |
| 408 | if action_input.state.dtype != engine_dtype: |
| 409 | action_input.state = action_input.state.to(engine_dtype) |
| 410 | if embodiment_id.dtype != torch.int64: |
| 411 | embodiment_id = embodiment_id.to(torch.int64) |
| 412 | |
| 413 | # --- State history reshape (N1.7) --- |
| 414 | # N1.7: state comes as [B, state_history_length, max_state_dim] |
| 415 | # Flatten to [B, 1, state_history_length * max_state_dim] for the encoder |
| 416 | state = action_input.state |
| 417 | if state.ndim == 3 and state.shape[1] > 1: |
| 418 | state = state.view(state.shape[0], 1, -1) |
| 419 | elif state.ndim == 3 and state.shape[1] == 1: |
| 420 | # Already [B, 1, dim] — state_history_length=1 |
| 421 | pass |
| 422 | else: |
| 423 | # Unexpected shape, pass through |
| 424 | logger.warning(f"Unexpected state shape: {state.shape}") |
| 425 | |
| 426 | # --- State Encoder TRT --- |
| 427 | self.state_encoder_engine.set_runtime_tensor_shape("state", state.shape) |
| 428 | self.state_encoder_engine.set_runtime_tensor_shape("embodiment_id", embodiment_id.shape) |
| 429 | state_features = self.state_encoder_engine(state, embodiment_id)["output"] |
| 430 |
nothing calls this directly
no test coverage detected