MCPcopy
hub / github.com/pytorch/vision / test_classification_model

Function test_classification_model

test/test_models.py:678–716  ·  view source on GitHub ↗
(model_fn, dev)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_classification_model(model_fn, dev):
    """Run one classification model builder end-to-end on ``dev``.

    The model is built with ``num_classes=50`` and fed an input of shape
    ``(1, 3, 224, 224)`` unless ``_model_params[model_fn.__name__]`` overrides
    either value. The eager output is checked with ``_assert_expected`` and its
    last dimension against ``num_classes``.

    The test then runs ``_check_jit_scriptable`` and ``_check_fx_compatible``.
    On CUDA it repeats the forward pass under autocast. Finally it runs
    ``_check_input_backprop``.

    Args:
        model_fn: model builder callable; its ``__name__`` keys the per-model
            overrides in ``_model_params`` and the expected-output lookup.
        dev: device the model and input are moved to (parametrized by
            ``cpu_and_cuda()``).
    """
    set_rng_seed(0)
    defaults = {
        "num_classes": 50,
        "input_shape": (1, 3, 224, 224),
    }
    model_name = model_fn.__name__
    if SKIP_BIG_MODEL and is_skippable(model_name, dev):
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
    # Per-model overrides win over the defaults.
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    # num_classes stays in kwargs (it is a real builder argument); the two
    # entries popped below are test-harness options, not builder arguments.
    num_classes = kwargs.get("num_classes")
    input_shape = kwargs.pop("input_shape")
    real_image = kwargs.pop("real_image", False)

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
    out = model(x)
    # FIXME: this special case is nasty and only here to please our CI prior
    # to the release. We should rethink these tests altogether.
    # FIXME: the default tolerance of 0.1 is probably still way too high.
    prec = 0.2 if model_name == "resnet101" else 0.1
    _assert_expected(out.cpu(), model_name, prec=prec)
    assert out.shape[-1] == num_classes
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast():
            out = model(x)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                _assert_expected(out.cpu(), model_name, prec=0.1)
            # Compare against the resolved num_classes (not a hard-coded 50)
            # so per-model overrides in _model_params are honoured here too.
            assert out.shape[-1] == num_classes

    _check_input_backprop(model, x)
717
718
719@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))

Callers 1

test_vitc_models — Function · 0.85

Calls 9

set_rng_seed — Function · 0.90
is_skippable — Function · 0.85
_get_image — Function · 0.85
_assert_expected — Function · 0.85
_check_jit_scriptable — Function · 0.85
_check_fx_compatible — Function · 0.85
_check_input_backprop — Function · 0.85
get — Method · 0.80
to — Method · 0.80

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…