MCPcopy
hub / github.com/pytorch/vision / test_quantized_classification_model

Function test_quantized_classification_model

test/test_models.py:967–1012  ·  view source on GitHub ↗
(model_fn)

Source from the content-addressed store, hash-verified

965)
@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
def test_quantized_classification_model(model_fn):
    """Smoke-test a quantizable classification model builder.

    Two phases are exercised:

    1. ``quantize=True``: build the quantized model, run one forward pass on
       random input and check the output width. Models not listed in
       ``quantized_flaky_models`` additionally have their output compared
       against stored expected values and are checked for TorchScript and FX
       compatibility; models in that list are only required to be scriptable.
    2. ``quantize=False``: build the float model and run the eager-mode
       quantization workflow (``fuse_model`` -> ``prepare``/``prepare_qat`` ->
       ``convert``), once in eval mode with ``default_qconfig`` and once in
       train mode with ``default_qat_qconfig``.

    Args:
        model_fn: a model builder from ``models.quantization``, supplied by
            ``pytest.mark.parametrize``.
    """
    set_rng_seed(0)
    defaults = {
        "num_classes": 5,
        "input_shape": (1, 3, 224, 224),
        "quantize": True,
    }
    model_name = model_fn.__name__
    # Per-model entries in ``_model_params`` take precedence over the defaults.
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    # ``input_shape`` describes the test input, not a builder argument.
    input_shape = kwargs.pop("input_shape")
    # Use the effective class count (after overrides) for the shape check
    # instead of repeating the default literal.
    num_classes = kwargs["num_classes"]

    # First check if quantize=True provides models that can run with input data
    model = model_fn(**kwargs)
    model.eval()
    x = torch.rand(input_shape)
    out = model(x)

    # The output width is fixed by the classifier head, so it is checked for
    # every model, including those whose output values are not compared.
    assert out.shape[-1] == num_classes

    if model_name not in quantized_flaky_models:
        _assert_expected(out.cpu(), model_name + "_quantized", prec=2e-2)
        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
        _check_fx_compatible(model, x, eager_out=out)
    else:
        # Models in ``quantized_flaky_models`` skip the expected-output
        # comparison (presumably because their quantized outputs are not
        # stable); they are only required to be scriptable.
        try:
            torch.jit.script(model)
        except Exception as e:
            raise AssertionError("model cannot be scripted.") from e

    # Then run the eager-mode quantization workflow on the float model, once in
    # eval mode (post-training quantization) and once in train mode (QAT).
    kwargs["quantize"] = False
    for eval_mode in [True, False]:
        model = model_fn(**kwargs)
        if eval_mode:
            model.eval()
            model.qconfig = torch.ao.quantization.default_qconfig
        else:
            model.train()
            model.qconfig = torch.ao.quantization.default_qat_qconfig

        # ``is_qat`` tells the model's own ``fuse_model`` which fusion flavour
        # to apply: QAT fusion in train mode, regular fusion in eval mode.
        model.fuse_model(is_qat=not eval_mode)
        if eval_mode:
            torch.ao.quantization.prepare(model, inplace=True)
        else:
            torch.ao.quantization.prepare_qat(model, inplace=True)
            # Switch back to eval mode before conversion.
            model.eval()

        # NOTE(review): the converted model is never run; this phase only
        # verifies that the fuse/prepare/convert workflow completes.
        torch.ao.quantization.convert(model, inplace=True)
1013
1014
1015@pytest.mark.parametrize("model_fn", list_model_fns(models.detection))

Callers

nothing calls this directly

Calls 8

set_rng_seed — Function · 0.90
_assert_expected — Function · 0.85
_check_jit_scriptable — Function · 0.85
_check_fx_compatible — Function · 0.85
get — Method · 0.80
train — Method · 0.80
prepare — Method · 0.80
fuse_model — Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…