MCPcopy
hub / github.com/microsoft/Cream / create_model

Function create_model

TinyCLIP/src/open_clip/factory.py:89–180  ·  view source on GitHub ↗
(
        model_name: str,
        pretrained: str = '',
        precision: str = 'fp32',
        device: torch.device = torch.device('cpu'),
        jit: bool = False,
        force_quick_gelu: bool = False,
        pretrained_image: bool = False,
        cache_dir: Optional[str] = None,
        args=None,
)

Source from the content-addressed store, hash-verified

87
88
89def create_model(
90 model_name: str,
91 pretrained: str = '',
92 precision: str = 'fp32',
93 device: torch.device = torch.device('cpu'),
94 jit: bool = False,
95 force_quick_gelu: bool = False,
96 pretrained_image: bool = False,
97 cache_dir: Optional[str] = None,
98 args=None,
99):
100 # for callers using old naming with / in ViT names
101 model_name = model_name.replace('/', '-')
102
103 if pretrained.lower() == 'openai':
104 logging.info(f'Loading pretrained {model_name} from OpenAI.')
105 model = load_openai_model(
106 model_name, device=device, jit=jit, cache_dir=cache_dir)
107 # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
108 if precision == "amp" or precision == "fp32":
109 model = model.float()
110 else:
111 if model_name in _MODEL_CONFIGS:
112 logging.info(f'Loading {model_name} model config.')
113 model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
114 else:
115 logging.error(
116 f'Model config for {model_name} not found; available models {list_models()}.')
117 raise RuntimeError(f'Model config for {model_name} not found.')
118
119 if force_quick_gelu:
120 # override for use of QuickGELU on non-OpenAI transformer models
121 model_cfg["quick_gelu"] = True
122
123 if pretrained_image:
124 if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
125 # pretrained weight loading for timm models set via vision_cfg
126 model_cfg['vision_cfg']['timm_model_pretrained'] = True
127 else:
128 assert False, 'pretrained image towers currently only supported for timm models'
129
130 if args is not None:
131 model_cfg['mask_image'] = getattr(args, 'prune_image', False)
132 model_cfg['mask_text'] = getattr(args, 'prune_text', False)
133 model_cfg['sparsity_warmup'] = getattr(
134 args, 'sparsity_warmup', 1000)
135 model_cfg['start_sparsity'] = getattr(args, 'start_sparsity', 0.0)
136 model_cfg['sparsity'] = getattr(args, 'target_sparsity', 0.25)
137 logging.info(
138 f'model sparsity varies from {model_cfg["start_sparsity"]} to {model_cfg["sparsity"]}, sparsity warmup steps: {model_cfg["sparsity_warmup"]}')
139
140 logging.info(str(model_cfg))
141 model = CLIP(**model_cfg)
142
143 pretrained_cfg = {}
144 if pretrained:
145 checkpoint_path = ''
146 pretrained_cfg = get_pretrained_cfg(model_name, pretrained)

Callers 6

main (Function) · 0.85
main (Function) · 0.85
main (Function) · 0.85
main (Function) · 0.85
main (Function) · 0.85

Calls 9

CLIP (Class) · 0.90
convert_weights_to_fp16 (Function) · 0.90
load_openai_model (Function) · 0.85
list_models (Function) · 0.85
get_pretrained_cfg (Function) · 0.85
download_pretrained (Function) · 0.85
to (Method) · 0.80
load_checkpoint (Function) · 0.70
get (Method) · 0.45

Tested by

no test coverage detected