MCPcopy Index your code
hub / github.com/Robbyant/lingbot-map / load_model

Function load_model

demo.py:131–163  ·  view source on GitHub ↗

Load GCTStream model from checkpoint.

(args, device)

Source from the content-addressed store, hash-verified

129# =============================================================================
130
def load_model(args, device):
    """Build a GCTStream model and optionally load weights from a checkpoint.

    The model class is chosen from ``args.mode``:
    ``"windowed"`` selects ``lingbot_map.models.gct_stream_window.GCTStream``.
    Any other value, or a missing ``mode`` attribute (treated as
    ``"streaming"``), selects ``lingbot_map.models.gct_stream.GCTStream``.

    Args:
        args: Parsed CLI namespace. Must provide ``image_size``,
            ``patch_size``, ``enable_3d_rope``, ``max_frame_num``,
            ``kv_cache_sliding_window``, ``num_scale_frames``, ``use_sdpa``,
            ``camera_num_iterations`` and ``model_path``. ``mode`` is
            optional. If ``model_path`` is falsy, the model keeps its
            freshly constructed weights.
        device: Target device, passed to ``Module.to``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    if getattr(args, "mode", "streaming") == "windowed":
        from lingbot_map.models.gct_stream_window import GCTStream
    else:
        from lingbot_map.models.gct_stream import GCTStream

    print("Building model...")
    model = GCTStream(
        img_size=args.image_size,
        patch_size=args.patch_size,
        enable_3d_rope=args.enable_3d_rope,
        max_frame_num=args.max_frame_num,
        kv_cache_sliding_window=args.kv_cache_sliding_window,
        kv_cache_scale_frames=args.num_scale_frames,
        kv_cache_cross_frame_special=True,
        kv_cache_include_scale_frames=True,
        use_sdpa=args.use_sdpa,
        camera_num_iterations=args.camera_num_iterations,
    )

    if args.model_path:
        print(f"Loading checkpoint: {args.model_path}")
        # Deserialize to CPU. load_state_dict copies the values into the
        # model's own parameters anyway. Mapping straight to `device` would
        # keep a second full copy of the weights (plus any extra checkpoint
        # entries, e.g. optimizer state) alive on the accelerator while the
        # model itself is moved there below.
        # NOTE(review): weights_only=False unpickles arbitrary Python objects
        # and can execute code. Only load checkpoints from trusted sources.
        ckpt = torch.load(args.model_path, map_location="cpu", weights_only=False)
        # Accept either a training-style dict with a "model" entry or a bare
        # state dict.
        state_dict = ckpt.get("model", ckpt)
        # strict=False tolerates partial or mismatched checkpoints. The
        # mismatches are only reported by count below.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        # Release the checkpoint before the model is moved to `device`.
        del ckpt, state_dict
        if missing:
            print(f"  Missing keys: {len(missing)}")
        if unexpected:
            print(f"  Unexpected keys: {len(unexpected)}")
        print("  Checkpoint loaded.")

    return model.to(device).eval()
164
165
166# =============================================================================

Callers 1

main (function) · 0.70

Calls 2

GCTStream (class) · 0.90
load (method) · 0.45

Tested by

no test coverage detected