(model_fn)
| 965 | ) |
@pytest.mark.parametrize("model_fn", list_model_fns(models.quantization))
def test_quantized_classification_model(model_fn):
    """Exercise one quantized classification model builder end to end.

    The test has two phases:

    1. Build the model with ``quantize=True``, run a forward pass on random
       input and validate the result. Models listed in
       ``quantized_flaky_models`` skip the output comparison and are only
       required to be scriptable with ``torch.jit.script``.
    2. Build the model with ``quantize=False`` and drive it through
       fuse -> prepare -> convert, once in eval mode (``prepare``) and once
       in train mode (``prepare_qat``).

    Args:
        model_fn: Model builder callable supplied by the parametrization;
            its ``__name__`` is used to look up per-model overrides and
            expected-output fixtures.
    """
    set_rng_seed(0)
    defaults = {
        "num_classes": 5,
        "input_shape": (1, 3, 224, 224),
        "quantize": True,
    }
    model_name = model_fn.__name__
    # Per-model overrides take precedence over the defaults above.
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    # input_shape only describes the test input; it is not a builder argument.
    input_shape = kwargs.pop("input_shape")
    # Take the class count from the merged kwargs so the shape assertion
    # below stays correct if a per-model override changes "num_classes".
    num_classes = kwargs["num_classes"]

    # First check if quantize=True provides models that can run with input data
    model = model_fn(**kwargs)
    model.eval()
    x = torch.rand(input_shape)
    out = model(x)

    if model_name not in quantized_flaky_models:
        _assert_expected(out.cpu(), model_name + "_quantized", prec=2e-2)
        assert out.shape[-1] == num_classes
        _check_jit_scriptable(
            model,
            (x,),
            unwrapper=script_model_unwrapper.get(model_name, None),
            eager_out=out,
        )
        _check_fx_compatible(model, x, eager_out=out)
    else:
        # Flaky models: skip the numeric comparison, only require scripting
        # to succeed. Any failure is surfaced as a test assertion failure.
        try:
            torch.jit.script(model)
        except Exception as e:
            raise AssertionError("model cannot be scripted.") from e

    # Second phase: start from the float model and quantize it manually.
    kwargs["quantize"] = False
    for eval_mode in [True, False]:
        model = model_fn(**kwargs)
        if eval_mode:
            model.eval()
            model.qconfig = torch.ao.quantization.default_qconfig
        else:
            model.train()
            model.qconfig = torch.ao.quantization.default_qat_qconfig

        model.fuse_model(is_qat=not eval_mode)
        if eval_mode:
            torch.ao.quantization.prepare(model, inplace=True)
        else:
            torch.ao.quantization.prepare_qat(model, inplace=True)
            # Switch back to eval mode before converting the QAT-prepared model.
            model.eval()

        torch.ao.quantization.convert(model, inplace=True)
| 1013 | |
| 1014 | |
| 1015 | @pytest.mark.parametrize("model_fn", list_model_fns(models.detection)) |
nothing calls this directly
no test coverage detected
searching dependent graphs…