Resolve pretrained configuration from various sources.
(
variant: str,
pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
) -> PretrainedCfg
| 346 | |
| 347 | |
def resolve_pretrained_cfg(
        variant: str,
        pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
) -> PretrainedCfg:
    """Resolve pretrained configuration from various sources.

    Resolution order:
      1. ``pretrained_cfg`` is a non-empty dict: it is validated by building a
         ``PretrainedCfg(**pretrained_cfg)`` and used directly.
      2. ``pretrained_cfg`` is a non-empty str: it is treated as a pretrained
         tag and the registry is queried for ``"{variant}.{tag}"`` only (there
         is no fallback to the bare ``variant`` key).
      3. ``pretrained_cfg`` is falsy: the registry is queried for ``variant``.
      Any other truthy value (e.g. an already-built ``PretrainedCfg``) is used
      as-is without validation.
      If the lookup yields nothing, a warning is logged and a default
      ``PretrainedCfg()`` is used.

    Finally, the entries of ``pretrained_cfg_overlay`` are applied on top of the
    resolved config via ``dataclasses.replace``. If the resolved config has no
    ``architecture`` and the overlay does not set one, ``architecture`` is set
    to ``variant``.

    Args:
        variant: Model variant name, used as the registry key (optionally
            suffixed with a tag) and as the fallback ``architecture``.
        pretrained_cfg: Explicit config dict, pretrained tag string, or None.
        pretrained_cfg_overlay: Field overrides applied to the resolved config.
            This dict is not modified.

    Returns:
        A new ``PretrainedCfg`` with the overlay applied.
    """
    model_with_tag = variant
    pretrained_tag = None
    if pretrained_cfg:
        if isinstance(pretrained_cfg, dict):
            # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
            pretrained_cfg = PretrainedCfg(**pretrained_cfg)
        elif isinstance(pretrained_cfg, str):
            # a string selects a tagged registry entry, resolved below
            pretrained_tag = pretrained_cfg
            pretrained_cfg = None

    # fallback to looking up pretrained cfg in model registry by variant identifier
    if not pretrained_cfg:
        if pretrained_tag:
            model_with_tag = '.'.join([variant, pretrained_tag])
        pretrained_cfg = get_pretrained_cfg(model_with_tag)

    if not pretrained_cfg:
        _logger.warning(
            f"No pretrained configuration specified for {model_with_tag} model. Using a default."
            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
        pretrained_cfg = PretrainedCfg()  # instance with defaults

    # Copy the overlay so setdefault() below never writes into the caller's dict;
    # otherwise a reused overlay would carry this variant's architecture into
    # later calls for other variants.
    overlay = dict(pretrained_cfg_overlay or {})
    if not pretrained_cfg.architecture:
        overlay.setdefault('architecture', variant)
    return dataclasses.replace(pretrained_cfg, **overlay)
| 382 | |
| 383 | |
| 384 | def build_model_with_cfg( |
no test coverage detected