| 181 | |
| 182 | |
def create_model_and_transforms(
        model_name: str,
        pretrained: str = '',
        precision: str = 'fp32',
        device: torch.device = torch.device('cpu'),
        jit: bool = False,
        force_quick_gelu: bool = False,
        pretrained_image: bool = False,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
        args=None,
):
    """Create a model plus its training and validation image transforms.

    The model is built by ``create_model``; ``model_name``, ``pretrained``,
    ``precision``, ``device``, ``jit``, ``force_quick_gelu``,
    ``pretrained_image``, ``cache_dir`` and ``args`` are forwarded to it
    unchanged (their exact meaning is defined by ``create_model``).

    Args:
        image_mean: Normalization mean handed to ``image_transform``. When
            falsy (``None`` or empty), ``model.visual.image_mean`` is used if
            that attribute exists, otherwise ``None``.
        image_std: Normalization std, resolved the same way as ``image_mean``
            but from ``model.visual.image_std``.

    Returns:
        A ``(model, preprocess_train, preprocess_val)`` tuple.
    """
    model = create_model(
        model_name,
        pretrained,
        precision,
        device,
        jit,
        force_quick_gelu=force_quick_gelu,
        pretrained_image=pretrained_image,
        cache_dir=cache_dir,
        args=args,
    )
    visual = model.visual

    # Explicit caller values win; otherwise fall back to whatever statistics
    # the visual tower advertises (or None when it advertises nothing).
    if not image_mean:
        image_mean = getattr(visual, 'image_mean', None)
    if not image_std:
        image_std = getattr(visual, 'image_std', None)

    # The validation transform keeps the aspect ratio for every model whose
    # name does not contain "davit" (case-insensitive).
    is_davit = 'davit' in model_name.lower()

    preprocess_train = image_transform(
        visual.image_size,
        is_train=True,
        mean=image_mean,
        std=image_std,
    )
    preprocess_val = image_transform(
        visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std,
        val_keep_ratio=not is_davit,
    )

    return model, preprocess_train, preprocess_val
| 212 | |
| 213 | |
| 214 | def list_models(): |