(model_fn, dev)
@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_segmentation_model(model_fn, dev):
    """Validate a segmentation model's output against its stored expected result.

    The model is built with small test defaults (``num_classes=10``, no
    backbone weights, a ``(1, 3, 32, 32)`` input), which can be overridden
    per model through ``_model_params``. The ``"out"`` entry of the model
    output is compared against the stored expected tensor. If that full
    comparison fails, only the argmax over dim 1 (the class predictions) is
    compared, and a ``RuntimeWarning`` is emitted at the end to flag the
    partial validation. The model is also checked for TorchScript and FX
    compatibility. On CUDA, it is re-run under autocast and re-validated,
    unless its name is listed in ``autocast_flaky_numerics``.
    """
    set_rng_seed(0)
    defaults = {
        "num_classes": 10,
        "weights_backbone": None,
        "input_shape": (1, 3, 32, 32),
    }
    model_name = model_fn.__name__
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    # input_shape describes the test input only; it is not a model constructor kwarg.
    input_shape = kwargs.pop("input_shape")

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
    x = torch.rand(input_shape).to(device=dev)
    with torch.no_grad(), freeze_rng_state():
        out = model(x)

    def check_out(out):
        """Check ``out`` against the expected result for ``model_name``.

        Returns True when the full output matched, and False when only the
        class predictions (argmax over dim 1) could be matched. Raises
        AssertionError if even the class predictions differ.
        """
        prec = 0.01
        try:
            # We first try to assert the entire output if possible. This is not
            # only the best way to assert results but also handles the cases
            # where we need to create a new expected result.
            _assert_expected(out.cpu(), model_name, prec=prec)
        except AssertionError:
            # Unfortunately some segmentation models are flaky with autocast
            # so instead of validating the probability scores, check that the class
            # predictions match.
            expected_file = _get_expected_file(model_name)
            expected = torch.load(expected_file, weights_only=True)
            # Class indices are integers and must match exactly: a relative
            # tolerance would accept e.g. a predicted 101 against an expected 100.
            # `out` may live on CUDA, so the device check against the loaded
            # expected tensor is skipped.
            torch.testing.assert_close(
                out.argmax(dim=1), expected.argmax(dim=1), rtol=0, atol=0, check_device=False
            )
            return False  # Partial validation performed

        return True  # Full validation performed

    full_validation = check_out(out["out"])

    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        # torch.autocast("cuda") replaces the deprecated torch.cuda.amp.autocast().
        with torch.autocast("cuda"), torch.no_grad(), freeze_rng_state():
            out = model(x)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                full_validation &= check_out(out["out"])

    if not full_validation:
        msg = (
            f"The output of {test_segmentation_model.__name__} could only be partially validated. "
            "This is likely due to unit-test flakiness, but you may "
            "want to do additional manual checks if you made "
            "significant changes to the codebase."
        )
        warnings.warn(msg, RuntimeWarning)
nothing calls this directly
no test coverage detected
searching dependent graphs…