(args)
| 170 | |
| 171 | |
| 172 | def validate(args): |
| 173 | # might as well try to validate something |
| 174 | args.pretrained = args.pretrained or not args.checkpoint |
| 175 | args.prefetcher = not args.no_prefetcher |
| 176 | |
| 177 | if torch.cuda.is_available(): |
| 178 | torch.backends.cuda.matmul.allow_tf32 = True |
| 179 | torch.backends.cudnn.benchmark = True |
| 180 | |
| 181 | device = torch.device(args.device) |
| 182 | |
| 183 | if args.metrics_avg and not has_sklearn: |
| 184 | _logger.warning( |
| 185 | f"scikit-learn not installed, disabling metrics calculation. Please install with 'pip install scikit-learn'.") |
| 186 | args.metrics_avg = None |
| 187 | |
| 188 | model_dtype = None |
| 189 | if args.model_dtype: |
| 190 | assert args.model_dtype in ('float32', 'float16', 'bfloat16') |
| 191 | model_dtype = getattr(torch, args.model_dtype) |
| 192 | |
| 193 | # resolve AMP arguments based on PyTorch availability |
| 194 | amp_autocast = suppress |
| 195 | if args.amp: |
| 196 | assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' |
| 197 | assert args.amp_dtype in ('float16', 'bfloat16') |
| 198 | amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 |
| 199 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) |
| 200 | _logger.info('Validating in mixed precision with native PyTorch AMP.') |
| 201 | else: |
| 202 | _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.') |
| 203 | |
| 204 | if args.fuser: |
| 205 | set_jit_fuser(args.fuser) |
| 206 | |
| 207 | if args.fast_norm: |
| 208 | set_fast_norm() |
| 209 | |
| 210 | # create model |
| 211 | in_chans = 3 |
| 212 | if args.in_chans is not None: |
| 213 | in_chans = args.in_chans |
| 214 | elif args.input_size is not None: |
| 215 | in_chans = args.input_size[0] |
| 216 | |
| 217 | model = create_model( |
| 218 | args.model, |
| 219 | pretrained=args.pretrained, |
| 220 | num_classes=args.num_classes, |
| 221 | in_chans=in_chans, |
| 222 | global_pool=args.gp, |
| 223 | scriptable=args.torchscript, |
| 224 | **args.model_kwargs, |
| 225 | ) |
| 226 | if args.num_classes is None: |
| 227 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' |
| 228 | args.num_classes = model.num_classes |
| 229 |
no test coverage detected