(
model_name: str,
pretrained: str = '',
precision: str = 'fp32',
device: torch.device = torch.device('cpu'),
jit: bool = False,
force_quick_gelu: bool = False,
pretrained_image: bool = False,
cache_dir: Optional[str] = None,
args=None,
)
| 87 | |
| 88 | |
| 89 | def create_model( |
| 90 | model_name: str, |
| 91 | pretrained: str = '', |
| 92 | precision: str = 'fp32', |
| 93 | device: torch.device = torch.device('cpu'), |
| 94 | jit: bool = False, |
| 95 | force_quick_gelu: bool = False, |
| 96 | pretrained_image: bool = False, |
| 97 | cache_dir: Optional[str] = None, |
| 98 | args=None, |
| 99 | ): |
| 100 | # for callers using old naming with / in ViT names |
| 101 | model_name = model_name.replace('/', '-') |
| 102 | |
| 103 | if pretrained.lower() == 'openai': |
| 104 | logging.info(f'Loading pretrained {model_name} from OpenAI.') |
| 105 | model = load_openai_model( |
| 106 | model_name, device=device, jit=jit, cache_dir=cache_dir) |
| 107 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 |
| 108 | if precision == "amp" or precision == "fp32": |
| 109 | model = model.float() |
| 110 | else: |
| 111 | if model_name in _MODEL_CONFIGS: |
| 112 | logging.info(f'Loading {model_name} model config.') |
| 113 | model_cfg = deepcopy(_MODEL_CONFIGS[model_name]) |
| 114 | else: |
| 115 | logging.error( |
| 116 | f'Model config for {model_name} not found; available models {list_models()}.') |
| 117 | raise RuntimeError(f'Model config for {model_name} not found.') |
| 118 | |
| 119 | if force_quick_gelu: |
| 120 | # override for use of QuickGELU on non-OpenAI transformer models |
| 121 | model_cfg["quick_gelu"] = True |
| 122 | |
| 123 | if pretrained_image: |
| 124 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): |
| 125 | # pretrained weight loading for timm models set via vision_cfg |
| 126 | model_cfg['vision_cfg']['timm_model_pretrained'] = True |
| 127 | else: |
| 128 | assert False, 'pretrained image towers currently only supported for timm models' |
| 129 | |
| 130 | if args is not None: |
| 131 | model_cfg['mask_image'] = getattr(args, 'prune_image', False) |
| 132 | model_cfg['mask_text'] = getattr(args, 'prune_text', False) |
| 133 | model_cfg['sparsity_warmup'] = getattr( |
| 134 | args, 'sparsity_warmup', 1000) |
| 135 | model_cfg['start_sparsity'] = getattr(args, 'start_sparsity', 0.0) |
| 136 | model_cfg['sparsity'] = getattr(args, 'target_sparsity', 0.25) |
| 137 | logging.info( |
| 138 | f'model sparsity varies from {model_cfg["start_sparsity"]} to {model_cfg["sparsity"]}, sparsity warmup steps: {model_cfg["sparsity_warmup"]}') |
| 139 | |
| 140 | logging.info(str(model_cfg)) |
| 141 | model = CLIP(**model_cfg) |
| 142 | |
| 143 | pretrained_cfg = {} |
| 144 | if pretrained: |
| 145 | checkpoint_path = '' |
| 146 | pretrained_cfg = get_pretrained_cfg(model_name, pretrained) |
no test coverage detected