Repository: github.com/NVlabs/RADIO

Function process_with_animation

examples/visualize_features.py, lines 301–478

Processes images with a radio1d_size animation in two passes. First pass: extract PCA statistics using the maximum radio1d_size. Second pass: generate frames for every radio1d_size value, reusing those statistics.
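The point of the two passes is presumably that every frame is colored with one shared PCA basis and one shared normalization. Refitting PCA per frame could flip component signs or rescale colors from frame to frame, making the animation flicker. The sketch below illustrates that idea with plain NumPy. It is not the repository's `get_pca_map`. The helper names `fit_pca_stats` and `apply_pca_stats`, the choice of 3 components, and the per-component min/max scaling are all assumptions made for illustration.

```python
import numpy as np


def fit_pca_stats(features: np.ndarray, n_components: int = 3):
    """Hypothetical helper: fit PCA plus min/max on reference features.

    features: (N, C) array, e.g. flattened H*W tokens from the largest
    radio1d_size. Returns (mean, components, lo, hi), reused for every frame.
    """
    mean = features.mean(axis=0)
    centered = features - mean
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    components = vt[:n_components]          # (k, C)
    projected = centered @ components.T     # (N, k)
    # Assumed normalization: per-component min/max of the reference projection.
    return mean, components, projected.min(axis=0), projected.max(axis=0)


def apply_pca_stats(features: np.ndarray, stats) -> np.ndarray:
    """Hypothetical helper: project with fixed stats and scale to [0, 1]."""
    mean, components, lo, hi = stats
    projected = (features - mean) @ components.T
    scaled = (projected - lo) / np.maximum(hi - lo, 1e-8)
    # Non-reference frames can land outside [lo, hi], so clip them.
    return np.clip(scaled, 0.0, 1.0)


rng = np.random.default_rng(0)
reference = rng.normal(size=(256, 16))  # stand-in for max-radio1d_size features
stats = fit_pca_stats(reference)        # first pass: fit once
frames = [apply_pca_stats(rng.normal(size=(256, 16)), stats) for _ in range(3)]  # second pass
```

Because `stats` is computed once and only applied afterwards, a given feature vector always maps to the same color regardless of which frame it appears in.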

(loader, model, preprocessor, device, args, dirs, anim_dirs)

```python
def process_with_animation(loader, model, preprocessor, device, args, dirs, anim_dirs):
    """
    Process images with radio1d_size animation.
    First pass: Extract PCA stats with max radio1d_size.
    Second pass: Generate frames for all radio1d_size values using those stats.
    """
    rank_print(f'Processing with radio1d animation: {args.radio1d_start} to {args.radio1d_end} step {args.radio1d_step}')

    # Collect all images first
    all_images = []
    ctr = 0
    for batches in loader:
        if ctr >= args.n:
            break
        for images, _ in batches:
            images = images.to(device=device, non_blocking=True)
            all_images.append(images)
            ctr += len(images)

    rank_print(f'Collected {len(all_images)} batches')

    # First pass: Get PCA stats from the highest radio1d_size
    rank_print(f'First pass: Computing PCA stats with radio1d_size={args.radio1d_end}')
    pca_stats_list = []

    for images in tqdm(all_images, desc='Computing PCA stats'):
        with torch.autocast(device.type, dtype=torch.bfloat16):
            p_images = preprocessor(images)

            if args.intermediates:
                outputs = model.forward_intermediates(
                    p_images,
                    indices=args.intermediates,
                    return_prefix_tokens=False,
                    norm=False,
                    stop_early=True,
                    output_fmt='NCHW',
                    intermediates_only=True,
                    aggregation=args.intermediate_aggregation,
                    norm_alpha_scheme="none",
                )
                all_feat = outputs
            else:
                kwargs = {}
                if args.neck:
                    kwargs['num_tokens'] = args.radio1d_end
                output = model(p_images, feature_fmt='NCHW', **kwargs)
                if args.adaptor_name:
                    all_feat = [
                        output['backbone'].features,
                        output[args.adaptor_name].features,
                    ]
                else:
                    all_feat = [output[1]]

        all_feat = [rearrange(f, 'b c h w -> b h w c').float() for f in all_feat]
        all_feat = list(zip(*all_feat))

        # ... (excerpt ends here; the function continues through line 478)
```
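The last two statements of the excerpt turn a list of per-source feature batches into per-image tuples: each source goes from channels-first to channels-last, and `zip(*...)` then walks the batch axis of every source in lockstep. A NumPy sketch of the same reshaping follows, assuming each source is a `(B, C, H, W)` array. The helper name `to_per_image_hwc` is hypothetical, and `np.transpose(f, (0, 2, 3, 1))` stands in for einops' `rearrange(f, 'b c h w -> b h w c')`.

```python
import numpy as np


def to_per_image_hwc(all_feat):
    """Hypothetical helper mirroring the excerpt's final two statements.

    all_feat: list of (B, C, H, W) arrays, one per feature source
    (e.g. backbone and adaptor). Returns a list of B tuples; tuple i holds
    image i's (H, W, C) map from every source, in source order.
    """
    # Same axis order as rearrange(f, 'b c h w -> b h w c'), then cast to float32.
    hwc = [np.transpose(f, (0, 2, 3, 1)).astype(np.float32) for f in all_feat]
    # Iterating an array yields slices along axis 0, so zip pairs images by index.
    return list(zip(*hwc))


backbone = np.random.rand(2, 8, 4, 5)  # stand-in for output['backbone'].features
adaptor = np.random.rand(2, 3, 4, 5)   # stand-in for an adaptor's features
per_image = to_per_image_hwc([backbone, adaptor])
```

Channel counts may differ between sources; only the batch size must match, since `zip` stops at the shortest sequence.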

Callers (1)

main (function) · 0.85

Calls (7)

rank_print (function) · 0.90
get_pca_map (function) · 0.85
add_text_overlay (function) · 0.85
save_animated_image (function) · 0.85
to (method) · 0.80
forward_intermediates (method) · 0.45

Tested by

no test coverage detected
