Load the GCTStream model from a checkpoint.
Signature: `load_model(args, device)`
| 129 | # ============================================================================= |
| 130 | |
def load_model(args, device):
    """Build a GCTStream model, optionally load checkpoint weights, and return it.

    The model variant is chosen from ``args.mode``: ``"windowed"`` selects
    ``lingbot_map.models.gct_stream_window.GCTStream``; any other value
    (``"streaming"`` when the attribute is absent) selects
    ``lingbot_map.models.gct_stream.GCTStream``.

    Args:
        args: Parsed options. Reads ``image_size``, ``patch_size``,
            ``enable_3d_rope``, ``max_frame_num``, ``kv_cache_sliding_window``,
            ``num_scale_frames``, ``use_sdpa``, ``camera_num_iterations``,
            ``model_path`` and, optionally, ``mode``.
        device: Device the returned model is moved to via ``Module.to``.

    Returns:
        The model on ``device`` in eval mode. If ``args.model_path`` is falsy,
        no checkpoint is loaded and the weights are left as constructed.
    """
    # Import lazily so only the selected variant's module is loaded.
    if getattr(args, "mode", "streaming") == "windowed":
        from lingbot_map.models.gct_stream_window import GCTStream
    else:
        from lingbot_map.models.gct_stream import GCTStream

    print("Building model...")
    model = GCTStream(
        img_size=args.image_size,
        patch_size=args.patch_size,
        enable_3d_rope=args.enable_3d_rope,
        max_frame_num=args.max_frame_num,
        kv_cache_sliding_window=args.kv_cache_sliding_window,
        kv_cache_scale_frames=args.num_scale_frames,
        kv_cache_cross_frame_special=True,
        kv_cache_include_scale_frames=True,
        use_sdpa=args.use_sdpa,
        camera_num_iterations=args.camera_num_iterations,
    )

    if args.model_path:
        print(f"Loading checkpoint: {args.model_path}")
        # SECURITY(review): weights_only=False unpickles arbitrary objects and
        # can execute code embedded in the file -- only load trusted checkpoints.
        # Load onto the CPU: load_state_dict copies the tensors into the model's
        # own parameters, so materialising the checkpoint on `device` would keep
        # a second full copy alive there while model.to(device) runs below.
        ckpt = torch.load(args.model_path, map_location="cpu", weights_only=False)
        # Checkpoints may wrap the weights under a "model" key; otherwise the
        # loaded object itself is used as the state dict.
        state_dict = ckpt.get("model", ckpt)
        # strict=False tolerates architecture drift, so report what was skipped
        # (with a few example names) instead of silently dropping weights.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print(f" Missing keys: {len(missing)} (e.g. {list(missing)[:5]})")
        if unexpected:
            print(f" Unexpected keys: {len(unexpected)} (e.g. {list(unexpected)[:5]})")
        # Drop references to the checkpoint tensors before moving the model.
        del ckpt, state_dict
        print(" Checkpoint loaded.")

    return model.to(device).eval()
| 164 | |
| 165 | |
| 166 | # ============================================================================= |