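Load a CLIP model. Parameters: name : str — A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict. precision : str — Model precision; if None, defaults to 'fp32' if device == 'cpu' else 'fp16'. device : Union[str, torch.device] — The device to put the loaded model. cache_dir : Optional[str] — The directory to cache the downloaded model weights.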
(
name: str,
precision: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
cache_dir: Optional[str] = None,
)
| 22 | |
| 23 | |
| 24 | def load_openai_model( |
| 25 | name: str, |
| 26 | precision: Optional[str] = None, |
| 27 | device: Optional[Union[str, torch.device]] = None, |
| 28 | cache_dir: Optional[str] = None, |
| 29 | ): |
| 30 | """Load a CLIP model |
| 31 | |
| 32 | Parameters |
| 33 | ---------- |
| 34 | name : str |
| 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict |
| 36 | precision: str |
| 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. |
| 38 | device : Union[str, torch.device] |
| 39 | The device to put the loaded model |
| 40 | cache_dir : Optional[str] |
| 41 | The directory to cache the downloaded model weights |
| 42 | |
| 43 | Returns |
| 44 | ------- |
| 45 | model : torch.nn.Module |
| 46 | The CLIP model |
| 47 | preprocess : Callable[[PIL.Image], torch.Tensor] |
| 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input |
| 49 | """ |
| 50 | if device is None: |
| 51 | device = "cuda" if torch.cuda.is_available() else "cpu" |
| 52 | if precision is None: |
| 53 | precision = 'fp32' if device == 'cpu' else 'fp16' |
| 54 | |
| 55 | if get_pretrained_url(name, 'openai'): |
| 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) |
| 57 | elif os.path.isfile(name): |
| 58 | model_path = name |
| 59 | else: |
| 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") |
| 61 | |
| 62 | try: |
| 63 | # loading JIT archive |
| 64 | model = torch.jit.load(model_path, map_location="cpu").eval() |
| 65 | state_dict = None |
| 66 | except RuntimeError: |
| 67 | # loading saved state dict |
| 68 | state_dict = torch.load(model_path, map_location="cpu") |
| 69 | |
| 70 | # Build a non-jit model from the OpenAI jitted model state dict |
| 71 | cast_dtype = get_cast_dtype(precision) |
| 72 | try: |
| 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) |
| 74 | except KeyError: |
| 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} |
| 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) |
| 77 | |
| 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use |
| 79 | model = model.to(device) |
| 80 | # FIXME support pure fp16/bf16 precision modes |
| 81 | if precision != 'fp16': |
nothing calls this directly
no test coverage detected