Main function for the VGGT demo with viser for 3D visualization. This function: (1) loads the VGGT model; (2) processes input images from the specified folder; (3) runs inference to generate 3D points and camera poses; (4) optionally applies sky segmentation to filter out sky points; and (5) visualizes the results using viser.
()
| 319 | |
| 320 | |
| 321 | def main(): |
| 322 | """ |
| 323 | Main function for the VGGT demo with viser for 3D visualization. |
| 324 | |
| 325 | This function: |
| 326 | 1. Loads the VGGT model |
| 327 | 2. Processes input images from the specified folder |
| 328 | 3. Runs inference to generate 3D points and camera poses |
| 329 | 4. Optionally applies sky segmentation to filter out sky points |
| 330 | 5. Visualizes the results using viser |
| 331 | |
| 332 | Command-line arguments: |
| 333 | --image_folder: Path to folder containing input images |
| 334 | --use_point_map: Use point map instead of depth-based points |
| 335 | --background_mode: Run the viser server in background mode |
| 336 | --port: Port number for the viser server |
| 337 | --conf_threshold: Initial percentage of low-confidence points to filter out |
| 338 | --mask_sky: Apply sky segmentation to filter out sky points |
| 339 | """ |
| 340 | args = parser.parse_args() |
| 341 | device = "cuda" if torch.cuda.is_available() else "cpu" |
| 342 | print(f"Using device: {device}") |
| 343 | |
| 344 | print("Initializing and loading VGGT model...") |
| 345 | # model = VGGT.from_pretrained("facebook/VGGT-1B") |
| 346 | |
| 347 | model = VGGT() |
| 348 | _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" |
| 349 | model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) |
| 350 | |
| 351 | model.eval() |
| 352 | model = model.to(device) |
| 353 | |
| 354 | # Use the provided image folder path |
| 355 | print(f"Loading images from {args.image_folder}...") |
| 356 | image_names = glob.glob(os.path.join(args.image_folder, "*")) |
| 357 | print(f"Found {len(image_names)} images") |
| 358 | |
| 359 | images = load_and_preprocess_images(image_names).to(device) |
| 360 | print(f"Preprocessed images shape: {images.shape}") |
| 361 | |
| 362 | print("Running inference...") |
| 363 | dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 |
| 364 | |
| 365 | with torch.no_grad(): |
| 366 | with torch.cuda.amp.autocast(dtype=dtype): |
| 367 | predictions = model(images) |
| 368 | |
| 369 | print("Converting pose encoding to extrinsic and intrinsic matrices...") |
| 370 | extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) |
| 371 | predictions["extrinsic"] = extrinsic |
| 372 | predictions["intrinsic"] = intrinsic |
| 373 | |
| 374 | print("Processing model outputs...") |
| 375 | for key in predictions.keys(): |
| 376 | if isinstance(predictions[key], torch.Tensor): |
| 377 | predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension and convert to numpy |
| 378 |
no test coverage detected