MCPcopy
hub / github.com/NVlabs/RADIO / load_model

Function load_model

examples/common/model_loader.py:339–501  ·  view source on GitHub ↗
(version: str, adaptor_names: str = None, use_huggingface: bool = False, use_local_lib: bool = True,
               device: torch.device = None, return_spatial_features: bool = True, force_reload: bool = False, neck_name: str = None,
               torchhub_repo="NVlabs/RADIO", **kwargs)

Source from the content-addressed store, hash-verified

337# and once it completes, it allows all other ranks to execute, using the now cached weights.
338@rank_gate
339def load_model(version: str, adaptor_names: str = None, use_huggingface: bool = False, use_local_lib: bool = True,
340 device: torch.device = None, return_spatial_features: bool = True, force_reload: bool = False, neck_name: str = None,
341 torchhub_repo="NVlabs/RADIO", **kwargs):
342 kwargs = {k: v for k, v in kwargs.items() if v is not None}
343
344 if os.path.isfile(version) or 'radio' in version:
345 if use_huggingface:
346 from transformers import AutoModel, AutoConfig
347 hf_repo = 'E-RADIO' if 'eradio' in version else 'RADIO'
348 hf_repo = f"nvidia/{hf_repo}"
349 config = AutoConfig.from_pretrained(
350 hf_repo,
351 trust_remote_code=True,
352 version=version,
353 adaptor_names=adaptor_names,
354 **kwargs,
355 )
356 model: nn.Module = AutoModel.from_pretrained(hf_repo, config=config, trust_remote_code=True, **kwargs)
357 elif use_local_lib:
358 from hubconf import radio_model
359 model, chk = radio_model(version=version, progress=True, adaptor_names=adaptor_names, return_checkpoint=True, neck_name=neck_name, **kwargs)
360 else:
361 model, chk = torch.hub.load(torchhub_repo, 'radio_model', version=version, progress=True,
362 adaptor_names=adaptor_names, return_spatial_features=return_spatial_features,
363 force_reload=force_reload,
364 return_checkpoint=True, neck_name=neck_name, **kwargs,
365 )
366
367 preprocessor = model.make_preprocessor_external()
368 info = ModelInfo(model_class='RADIO', model_subtype=version.replace('/', '_'), checkpoint=chk)
369 elif version.startswith('dinov2'):
370 model = torch.hub.load('facebookresearch/dinov2', version, pretrained=True, force_reload=force_reload, **kwargs)
371 model = DinoWrapper(model)
372
373 preprocessor = InputConditioner(1.0, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
374 info = ModelInfo(model_class='DINOv2', model_subtype=version.replace('dinov2_', ''))
375 elif version.startswith('dinov3'):
376 _, chk_path = version.split(',')
377 fname = os.path.basename(chk_path)
378 model_version = fname[:fname.index('_pretrain')]
379 model: nn.Module = torch.hub.load('facebookresearch/dinov3', model_version, pretrained=False)
380
381 chk = torch.load(chk_path, map_location='cpu')
382 model.load_state_dict(chk, strict=True)
383 model = DinoWrapper(model, patch_attn=False)
384
385 preprocessor = InputConditioner(1.0, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
386 info = ModelInfo(model_class='DINOv3', model_subtype=model_version.replace('dinov3_', ''))
387 elif version.startswith('open_clip'):
388 import open_clip
389 _, model_arch, pretrained = version.split(',')
390 model = open_clip.create_model(model_arch, pretrained, device=device)
391 viz_model = model.visual
392
393 preprocessor = InputConditioner(1.0,
394 getattr(viz_model, 'image_mean', open_clip.OPENAI_DATASET_MEAN),
395 getattr(viz_model, 'image_std', open_clip.OPENAI_DATASET_STD),
396 )

Callers 9

main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
main · Function · 0.90
get_feature_matrix · Function · 0.90
main · Function · 0.90
main · Function · 0.90
main · Function · 0.90

Calls 12

radio_model · Function · 0.90
InputConditioner · Class · 0.90
ModelInfo · Class · 0.85
CLIPWrapper · Class · 0.85
SAMWrapper · Class · 0.85
InternViTWrapper · Class · 0.85
SigLIP2Wrapper · Class · 0.85
to · Method · 0.80
DinoWrapper · Class · 0.70

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…