Process images with radio1d_size animation. First pass: Extract PCA stats with max radio1d_size. Second pass: Generate frames for all radio1d_size values using those stats.
(loader, model, preprocessor, device, args, dirs, anim_dirs)
| 299 | |
| 300 | |
| 301 | def process_with_animation(loader, model, preprocessor, device, args, dirs, anim_dirs): |
| 302 | """ |
| 303 | Process images with radio1d_size animation. |
| 304 | First pass: Extract PCA stats with max radio1d_size. |
| 305 | Second pass: Generate frames for all radio1d_size values using those stats. |
| 306 | """ |
| 307 | rank_print(f'Processing with radio1d animation: {args.radio1d_start} to {args.radio1d_end} step {args.radio1d_step}') |
| 308 | |
| 309 | # Collect all images first |
| 310 | all_images = [] |
| 311 | ctr = 0 |
| 312 | for batches in loader: |
| 313 | if ctr >= args.n: |
| 314 | break |
| 315 | for images, _ in batches: |
| 316 | images = images.to(device=device, non_blocking=True) |
| 317 | all_images.append(images) |
| 318 | ctr += len(images) |
| 319 | |
| 320 | rank_print(f'Collected {len(all_images)} batches') |
| 321 | |
| 322 | # First pass: Get PCA stats from the highest radio1d_size |
| 323 | rank_print(f'First pass: Computing PCA stats with radio1d_size={args.radio1d_end}') |
| 324 | pca_stats_list = [] |
| 325 | |
| 326 | for images in tqdm(all_images, desc='Computing PCA stats'): |
| 327 | with torch.autocast(device.type, dtype=torch.bfloat16): |
| 328 | p_images = preprocessor(images) |
| 329 | |
| 330 | if args.intermediates: |
| 331 | outputs = model.forward_intermediates( |
| 332 | p_images, |
| 333 | indices=args.intermediates, |
| 334 | return_prefix_tokens=False, |
| 335 | norm=False, |
| 336 | stop_early=True, |
| 337 | output_fmt='NCHW', |
| 338 | intermediates_only=True, |
| 339 | aggregation=args.intermediate_aggregation, |
| 340 | norm_alpha_scheme="none", |
| 341 | ) |
| 342 | all_feat = outputs |
| 343 | else: |
| 344 | kwargs = {} |
| 345 | if args.neck: |
| 346 | kwargs['num_tokens'] = args.radio1d_end |
| 347 | output = model(p_images, feature_fmt='NCHW', **kwargs) |
| 348 | if args.adaptor_name: |
| 349 | all_feat = [ |
| 350 | output['backbone'].features, |
| 351 | output[args.adaptor_name].features, |
| 352 | ] |
| 353 | else: |
| 354 | all_feat = [output[1]] |
| 355 | |
| 356 | all_feat = [rearrange(f, 'b c h w -> b h w c').float() for f in all_feat] |
| 357 | all_feat = list(zip(*all_feat)) |
| 358 |
no test coverage detected
searching dependent graphs…