Repository: github.com/huggingface/pytorch-image-models

Function validate

validate.py, lines 172–453
Signature: validate(args)


```python
# Excerpt from validate.py (lines 172–229). Module-level names used here are imported or
# defined elsewhere in the file: torch, suppress (contextlib), partial (functools),
# has_sklearn, _logger, set_jit_fuser, set_fast_norm, create_model.
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    if torch.cuda.is_available():
        # use TF32 for float32 matmuls (Ampere+ GPUs) and let cuDNN autotune conv algorithms
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(args.device)

    if args.metrics_avg and not has_sklearn:
        _logger.warning(
            f"scikit-learn not installed, disabling metrics calculation. Please install with 'pip install scikit-learn'.")
        args.metrics_avg = None

    model_dtype = None
    if args.model_dtype:
        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
        model_dtype = getattr(torch, args.model_dtype)

    # resolve AMP arguments based on PyTorch availability
    amp_autocast = suppress
    if args.amp:
        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
        assert args.amp_dtype in ('float16', 'bfloat16')
        amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    else:
        _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')

    if args.fuser:
        set_jit_fuser(args.fuser)

    if args.fast_norm:
        set_fast_norm()

    # create model: explicit --in-chans wins, else C from --input-size (C, H, W), else 3
    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=in_chans,
        global_pool=args.gp,
        scriptable=args.torchscript,
        **args.model_kwargs,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    # ... (function continues through line 453; not shown in this excerpt)
```
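The first two statements of `validate` normalize flags: when no `--checkpoint` is given, pretrained weights are loaded so there is something to evaluate, and the prefetcher stays on unless `--no-prefetcher` is passed. A minimal stdlib sketch of that logic, using an `argparse.Namespace` that carries only the attributes involved (`normalize_args` is a hypothetical helper, not part of timm):

```python
from argparse import Namespace


def normalize_args(args: Namespace) -> Namespace:
    """Mirror the flag normalization at the top of validate() (hypothetical helper)."""
    # Without a checkpoint, fall back to pretrained weights so there is something to validate.
    args.pretrained = args.pretrained or not args.checkpoint
    # The prefetcher is enabled unless explicitly disabled.
    args.prefetcher = not args.no_prefetcher
    return args


# No checkpoint given: pretrained is forced on.
a = normalize_args(Namespace(pretrained=False, checkpoint='', no_prefetcher=False))
print(a.pretrained, a.prefetcher)  # True True

# Checkpoint given: pretrained stays as passed, prefetcher disabled on request.
b = normalize_args(Namespace(pretrained=False, checkpoint='model.pth', no_prefetcher=True))
print(b.pretrained, b.prefetcher)  # False False
```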
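`validate` turns metric averaging off (sets `args.metrics_avg` to `None`) with a warning when `has_sklearn` is false. How `has_sklearn` is computed is not shown in this excerpt; the sketch below assumes a simple availability check via `importlib.util.find_spec`, and `resolve_metrics_avg` is a hypothetical helper that isolates the downgrade logic. `'macro'` is only an example value.

```python
import importlib.util
import logging

_logger = logging.getLogger(__name__)

# Assumption: availability is detected by whether the package can be located.
has_sklearn = importlib.util.find_spec("sklearn") is not None


def resolve_metrics_avg(metrics_avg, sklearn_available):
    """Disable metric averaging when scikit-learn is missing (hypothetical helper)."""
    if metrics_avg and not sklearn_available:
        _logger.warning(
            "scikit-learn not installed, disabling metrics calculation. "
            "Please install with 'pip install scikit-learn'.")
        return None
    return metrics_avg


print(resolve_metrics_avg("macro", sklearn_available=False))  # None
print(resolve_metrics_avg("macro", sklearn_available=True))   # macro
```

Passing the availability flag explicitly keeps the helper deterministic and testable regardless of what is installed.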
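The precision setup in `validate` always yields a callable `amp_autocast`: `contextlib.suppress` (a no-op context manager when called with no arguments) when AMP is off, or `partial(torch.autocast, ...)` when it is on, so the evaluation loop can write `with amp_autocast():` unconditionally. The sketch below reproduces those checks with dtype names as strings; `fake_autocast` is a hypothetical stand-in for `torch.autocast` so it runs without PyTorch, and `resolve_precision` is likewise hypothetical.

```python
from contextlib import contextmanager, suppress
from functools import partial

VALID_MODEL_DTYPES = ('float32', 'float16', 'bfloat16')
VALID_AMP_DTYPES = ('float16', 'bfloat16')


@contextmanager
def fake_autocast(device_type, dtype):
    """Stand-in for torch.autocast so this sketch runs without PyTorch."""
    print(f'autocast enabled on {device_type} with {dtype}')
    yield


def resolve_precision(model_dtype=None, amp=False, amp_dtype='float16',
                      device_type='cuda', autocast=fake_autocast):
    """Return (model_dtype, autocast_factory) using the same checks as validate()."""
    model_dtype = model_dtype or None  # empty string means "not set"
    if model_dtype is not None:
        assert model_dtype in VALID_MODEL_DTYPES
    if amp:
        # AMP keeps weights in float32 and lowers precision only inside the autocast region.
        assert model_dtype in (None, 'float32'), 'float32 model dtype must be used with AMP'
        assert amp_dtype in VALID_AMP_DTYPES
        return model_dtype, partial(autocast, device_type=device_type, dtype=amp_dtype)
    # Without AMP, suppress() with no exception types is a do-nothing context manager.
    return model_dtype, suppress


_, amp_autocast = resolve_precision(amp=True, amp_dtype='bfloat16')
with amp_autocast():
    pass  # the model forward pass would run here

_, no_amp = resolve_precision(model_dtype='float16')
with no_amp():
    pass  # runs in the model's own dtype, no autocast
```

Returning a factory rather than an already-entered context is what lets the same `with` statement serve both the AMP and non-AMP paths.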
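Model construction in `validate` resolves two defaults: input channels (explicit `--in-chans`, else the first element of a `(C, H, W)` `--input-size`, else 3) and, after the model is built, `num_classes` taken from the model itself when it was not configured. A stdlib sketch of both rules; `resolve_in_chans` and `resolve_num_classes` are hypothetical helpers, and a `SimpleNamespace` stands in for a timm model:

```python
from types import SimpleNamespace


def resolve_in_chans(in_chans=None, input_size=None, default=3):
    """Explicit in_chans first, then C from a (C, H, W) input size, else the default."""
    if in_chans is not None:
        return in_chans
    if input_size is not None:
        return input_size[0]
    return default


def resolve_num_classes(num_classes, model):
    """Fall back to the model's own num_classes attribute when none was configured."""
    if num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        num_classes = model.num_classes
    return num_classes


print(resolve_in_chans())                                      # 3
print(resolve_in_chans(input_size=(1, 224, 224)))              # 1
print(resolve_in_chans(in_chans=4, input_size=(1, 224, 224)))  # 4

model = SimpleNamespace(num_classes=1000)  # stand-in for a created model
print(resolve_num_classes(None, model))  # 1000
print(resolve_num_classes(10, model))    # 10
```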

Callers (2)

_try_run (function, score 0.70)
main (function, score 0.70)

Calls (15; 12 listed)

add_result (method, score 0.95)
update (method, score 0.95)
get_accuracy (method, score 0.95)
set_jit_fuser (function, score 0.90)
set_fast_norm (function, score 0.90)
create_model (function, score 0.90)
load_checkpoint (function, score 0.90)
reparameterize_model (function, score 0.90)
resolve_data_config (function, score 0.90)
apply_test_time_pool (function, score 0.90)
create_dataset (function, score 0.90)
RealLabelsImagenet (class, score 0.90)
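The `update` method in the call list is most likely a running-average meter's update (timm's scripts use `AverageMeter` from `timm.utils` for loss and top-1/top-5 tracking), though those call sites fall outside the excerpt shown. A minimal sketch of such a meter, written for illustration; the `RunningAverage` class is hypothetical and weights each batch by its size:

```python
class RunningAverage:
    """Batch-size-weighted running average (hypothetical, AverageMeter-style)."""

    def __init__(self):
        self.val = 0.0   # most recent value passed to update()
        self.sum = 0.0   # running sum of value * sample count
        self.count = 0   # total number of samples seen
        self.avg = 0.0   # mean over all samples so far

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


top1 = RunningAverage()
top1.update(75.0, n=256)  # full batch of 256 images at 75% top-1
top1.update(50.0, n=64)   # smaller final batch at 50% top-1
print(top1.avg)  # 70.0
```

Weighting by `n` matters when the last batch is smaller: an unweighted mean of the two batch accuracies would report 62.5 instead of the true per-image 70.0.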

Tested by

no test coverage detected