(model_fn, dev)
@pytest.mark.parametrize("model_fn", list_model_fns(models))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_classification_model(model_fn, dev):
    """Run the generic checks for one classification model builder on ``dev``.

    The model is built with ``num_classes=50`` and tested on an input of shape
    ``(1, 3, 224, 224)``, unless ``_model_params`` overrides these for this
    model. After an eval-mode forward pass, the output is checked with
    ``_assert_expected`` and its last dimension must equal ``num_classes``.
    The model is then passed to ``_check_jit_scriptable`` and
    ``_check_fx_compatible``. When ``dev == "cuda"`` it is re-run under
    ``torch.cuda.amp.autocast``. Finally it is passed to
    ``_check_input_backprop``.
    """
    set_rng_seed(0)
    defaults = {
        "num_classes": 50,
        "input_shape": (1, 3, 224, 224),
    }
    model_name = model_fn.__name__
    if SKIP_BIG_MODEL and is_skippable(model_name, dev):
        pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
    # Per-model overrides take precedence over the defaults.
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    # num_classes is read but intentionally left in kwargs: the builder needs it.
    num_classes = kwargs.get("num_classes")
    # input_shape / real_image only describe the test input, so they are
    # removed before kwargs is forwarded to model_fn.
    input_shape = kwargs.pop("input_shape")
    real_image = kwargs.pop("real_image", False)

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    x = _get_image(input_shape=input_shape, real_image=real_image, device=dev)
    out = model(x)
    # FIXME: the resnet101 special case is nasty and only here to please our CI
    # prior to the release. We should rethink these tests altogether.
    # FIXME: the default tolerance of 0.1 is probably still way too high.
    prec = 0.2 if model_name == "resnet101" else 0.1
    _assert_expected(out.cpu(), model_name, prec=prec)
    assert out.shape[-1] == num_classes
    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        with torch.cuda.amp.autocast():
            out = model(x)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                _assert_expected(out.cpu(), model_name, prec=0.1)
            # Compare against the configured class count (not a hard-coded 50)
            # so per-model num_classes overrides are honored here as well.
            assert out.shape[-1] == num_classes

    _check_input_backprop(model, x)
| 717 | |
| 718 | |
| 719 | @pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation)) |
no test coverage detected
searching dependent graphs…