MCPcopy
hub / github.com/mlfoundations/open_clip / load_openai_model

Function load_openai_model

src/open_clip/openai.py:24–90  ·  view source on GitHub ↗

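Load a CLIP model. Parameters: `name` (str) — a model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict; `precision` (str) — model precision; if None, defaults to 'fp32' if device == 'cpu' else 'fp16'; `device` (Union[str, torch.device]) — the device to put the loaded model on; `cache_dir` (Optional[str]) — the directory to cache the downloaded model weights. Returns: `model` (torch.nn.Module), the CLIP model, and `preprocess` (Callable[[PIL.Image], torch.Tensor]), a torchvision transform that converts a PIL image into a tensor that the returned model can take as its input.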

(
        name: str,
        precision: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        cache_dir: Optional[str] = None,
)

Source from the content-addressed store, hash-verified

22
23
24def load_openai_model(
25 name: str,
26 precision: Optional[str] = None,
27 device: Optional[Union[str, torch.device]] = None,
28 cache_dir: Optional[str] = None,
29):
30 """Load a CLIP model
31
32 Parameters
33 ----------
34 name : str
35 A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 precision: str
37 Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 device : Union[str, torch.device]
39 The device to put the loaded model
40 cache_dir : Optional[str]
41 The directory to cache the downloaded model weights
42
43 Returns
44 -------
45 model : torch.nn.Module
46 The CLIP model
47 preprocess : Callable[[PIL.Image], torch.Tensor]
48 A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
49 """
50 if device is None:
51 device = "cuda" if torch.cuda.is_available() else "cpu"
52 if precision is None:
53 precision = 'fp32' if device == 'cpu' else 'fp16'
54
55 if get_pretrained_url(name, 'openai'):
56 model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
57 elif os.path.isfile(name):
58 model_path = name
59 else:
60 raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
61
62 try:
63 # loading JIT archive
64 model = torch.jit.load(model_path, map_location="cpu").eval()
65 state_dict = None
66 except RuntimeError:
67 # loading saved state dict
68 state_dict = torch.load(model_path, map_location="cpu")
69
70 # Build a non-jit model from the OpenAI jitted model state dict
71 cast_dtype = get_cast_dtype(precision)
72 try:
73 model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
74 except KeyError:
75 sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
76 model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
77
78 # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
79 model = model.to(device)
80 # FIXME support pure fp16/bf16 precision modes
81 if precision != 'fp16':

Callers

nothing calls this directly

Calls 6

get_pretrained_url · Function · 0.85
list_openai_models · Function · 0.85
get_cast_dtype · Function · 0.85
convert_weights_to_lp · Function · 0.85

Tested by

no test coverage detected