Load LingbotMap (GCTStream) model from checkpoint.
(self)
| 106 | self._load_model() |
| 107 | |
def _load_model(self):
    """Build the LingbotMap (GCTStream) model and optionally load a checkpoint.

    The model class is chosen from ``self.mode``: ``'windowed'`` selects
    ``lingbot_map.models.gct_stream_window.GCTStream``; any other value uses
    ``lingbot_map.models.gct_stream.GCTStream``. Constructor arguments are
    taken from the matching attributes on ``self``.

    If ``self.checkpoint`` is truthy, it is read with ``torch.load`` and its
    weights are copied into the model with ``strict=False``. The counts of
    missing / unexpected keys are printed, along with a short preview of
    their names. The finished model is moved to ``self.device``, switched
    to eval mode, and stored on ``self.model``.

    Raises:
        TypeError: if the checkpoint is neither a mapping (state dict or a
            dict wrapping one under ``"model"``) nor an ``nn.Module``.
    """
    # Import lazily so only the selected variant's module is loaded.
    if self.mode == 'windowed':
        from lingbot_map.models.gct_stream_window import GCTStream
    else:
        from lingbot_map.models.gct_stream import GCTStream

    print(f" → Building LingbotMap model (mode: {self.mode})")
    self.model = GCTStream(
        img_size=self.image_size,
        patch_size=self.patch_size,
        enable_3d_rope=self.enable_3d_rope,
        max_frame_num=self.max_frame_num,
        kv_cache_sliding_window=self.kv_cache_sliding_window,
        kv_cache_scale_frames=self.kv_cache_scale_frames,
        kv_cache_cross_frame_special=True,
        kv_cache_include_scale_frames=True,
        use_sdpa=self.use_sdpa,
    )

    if self.checkpoint:
        print(f" → Loading checkpoint: {self.checkpoint}")
        # The freshly built model is still on CPU, so deserialize onto CPU
        # as well. Loading directly onto self.device would make
        # load_state_dict copy every tensor device -> CPU, and would keep a
        # second full copy of the weights on the device until return.
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file -- only load trusted files.
        ckpt = torch.load(self.checkpoint, map_location="cpu", weights_only=False)

        # Accept a bare state dict, a dict wrapping one under "model", or a
        # whole pickled module.
        if isinstance(ckpt, torch.nn.Module):
            state_dict = ckpt.state_dict()
        elif isinstance(ckpt, dict):
            state_dict = ckpt.get("model", ckpt)
        else:
            raise TypeError(
                f"Unsupported checkpoint type {type(ckpt).__name__!r} "
                f"in {self.checkpoint}"
            )

        # strict=False tolerates partial checkpoints; report mismatches so a
        # wrong or differently-prefixed checkpoint does not pass silently.
        missing, unexpected = self.model.load_state_dict(state_dict, strict=False)
        preview_len = 5
        if missing:
            print(f" Missing keys: {len(missing)}")
            print(f"   e.g. {', '.join(missing[:preview_len])}")
        if unexpected:
            print(f" Unexpected keys: {len(unexpected)}")
            print(f"   e.g. {', '.join(unexpected[:preview_len])}")
        # Weights have been copied into the model's parameters; release the
        # checkpoint tensors before moving the model to its device.
        del ckpt, state_dict
        print(" Checkpoint loaded.")

    self.model = self.model.to(self.device).eval()
| 141 | def _prepare_images(self, rgb_list): |
| 142 | """Convert list of HxWx3 uint8 numpy arrays to [S, 3, H, W] tensor in [0, 1].""" |