(version: str, adaptor_names: str = None, use_huggingface: bool = False, use_local_lib: bool = True,
device: torch.device = None, return_spatial_features: bool = True, force_reload: bool = False, neck_name: str = None,
torchhub_repo="NVlabs/RADIO", **kwargs)
| 337 | # and once it completes, it allows all other ranks to execute, using the now cached weights. |
| 338 | @rank_gate |
| 339 | def load_model(version: str, adaptor_names: str = None, use_huggingface: bool = False, use_local_lib: bool = True, |
| 340 | device: torch.device = None, return_spatial_features: bool = True, force_reload: bool = False, neck_name: str = None, |
| 341 | torchhub_repo="NVlabs/RADIO", **kwargs): |
| 342 | kwargs = {k: v for k, v in kwargs.items() if v is not None} |
| 343 | |
| 344 | if os.path.isfile(version) or 'radio' in version: |
| 345 | if use_huggingface: |
| 346 | from transformers import AutoModel, AutoConfig |
| 347 | hf_repo = 'E-RADIO' if 'eradio' in version else 'RADIO' |
| 348 | hf_repo = f"nvidia/{hf_repo}" |
| 349 | config = AutoConfig.from_pretrained( |
| 350 | hf_repo, |
| 351 | trust_remote_code=True, |
| 352 | version=version, |
| 353 | adaptor_names=adaptor_names, |
| 354 | **kwargs, |
| 355 | ) |
| 356 | model: nn.Module = AutoModel.from_pretrained(hf_repo, config=config, trust_remote_code=True, **kwargs) |
| 357 | elif use_local_lib: |
| 358 | from hubconf import radio_model |
| 359 | model, chk = radio_model(version=version, progress=True, adaptor_names=adaptor_names, return_checkpoint=True, neck_name=neck_name, **kwargs) |
| 360 | else: |
| 361 | model, chk = torch.hub.load(torchhub_repo, 'radio_model', version=version, progress=True, |
| 362 | adaptor_names=adaptor_names, return_spatial_features=return_spatial_features, |
| 363 | force_reload=force_reload, |
| 364 | return_checkpoint=True, neck_name=neck_name, **kwargs, |
| 365 | ) |
| 366 | |
| 367 | preprocessor = model.make_preprocessor_external() |
| 368 | info = ModelInfo(model_class='RADIO', model_subtype=version.replace('/', '_'), checkpoint=chk) |
| 369 | elif version.startswith('dinov2'): |
| 370 | model = torch.hub.load('facebookresearch/dinov2', version, pretrained=True, force_reload=force_reload, **kwargs) |
| 371 | model = DinoWrapper(model) |
| 372 | |
| 373 | preprocessor = InputConditioner(1.0, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) |
| 374 | info = ModelInfo(model_class='DINOv2', model_subtype=version.replace('dinov2_', '')) |
| 375 | elif version.startswith('dinov3'): |
| 376 | _, chk_path = version.split(',') |
| 377 | fname = os.path.basename(chk_path) |
| 378 | model_version = fname[:fname.index('_pretrain')] |
| 379 | model: nn.Module = torch.hub.load('facebookresearch/dinov3', model_version, pretrained=False) |
| 380 | |
| 381 | chk = torch.load(chk_path, map_location='cpu') |
| 382 | model.load_state_dict(chk, strict=True) |
| 383 | model = DinoWrapper(model, patch_attn=False) |
| 384 | |
| 385 | preprocessor = InputConditioner(1.0, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) |
| 386 | info = ModelInfo(model_class='DINOv3', model_subtype=model_version.replace('dinov3_', '')) |
| 387 | elif version.startswith('open_clip'): |
| 388 | import open_clip |
| 389 | _, model_arch, pretrained = version.split(',') |
| 390 | model = open_clip.create_model(model_arch, pretrained, device=device) |
| 391 | viz_model = model.visual |
| 392 | |
| 393 | preprocessor = InputConditioner(1.0, |
| 394 | getattr(viz_model, 'image_mean', open_clip.OPENAI_DATASET_MEAN), |
| 395 | getattr(viz_model, 'image_std', open_clip.OPENAI_DATASET_STD), |
| 396 | ) |
no test coverage detected
searching dependent graphs…