MCPcopy
hub / github.com/pytorch/vision / test_segmentation_model

Function test_segmentation_model

test/test_models.py:721–781  ·  view source on GitHub ↗
(model_fn, dev)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("model_fn", list_model_fns(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_segmentation_model(model_fn, dev):
    """Validate a segmentation model's eager output against the stored expected result.

    The model is built from ``model_fn`` with small default arguments (which
    ``_model_params`` may override per model name), put in eval mode, moved to
    ``dev`` and run on a random input. ``out["out"]`` is compared with the
    stored expected tensor, then the ``_check_jit_scriptable`` and
    ``_check_fx_compatible`` helpers are run on the same model and input. On
    CUDA the forward pass is repeated under autocast and, unless the model is
    listed in ``autocast_flaky_numerics``, validated again.

    A ``RuntimeWarning`` is emitted when any validation had to fall back to
    comparing class predictions only.

    Args:
        model_fn: Model builder; parametrized over
            ``list_model_fns(models.segmentation)``.
        dev: Device the model and the input are moved to; parametrized over
            ``cpu_and_cuda()``.
    """
    set_rng_seed(0)
    defaults = {
        "num_classes": 10,
        "weights_backbone": None,
        "input_shape": (1, 3, 32, 32),
    }
    model_name = model_fn.__name__
    kwargs = {**defaults, **_model_params.get(model_name, {})}
    # "input_shape" only describes the test input; it is not a builder argument.
    input_shape = kwargs.pop("input_shape")

    model = model_fn(**kwargs)
    model.eval().to(device=dev)
    # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
    x = torch.rand(input_shape).to(device=dev)
    with torch.no_grad(), freeze_rng_state():
        out = model(x)

    def check_out(out):
        """Compare ``out`` with the expected result stored for ``model_name``.

        Returns:
            True if the whole tensor matched, False if only the argmax over
            dim 1 matched.

        Raises:
            AssertionError: if the argmax comparison fails as well.
        """
        prec = 0.01
        try:
            # We first try to assert the entire output if possible. This is not
            # only the best way to assert results but also handles the cases
            # where we need to create a new expected result.
            _assert_expected(out.cpu(), model_name, prec=prec)
        except AssertionError:
            # Unfortunately some segmentation models are flaky with autocast
            # so instead of validating the probability scores, check that the class
            # predictions match.
            expected_file = _get_expected_file(model_name)
            expected = torch.load(expected_file, weights_only=True)
            torch.testing.assert_close(
                out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec, check_device=False
            )
            return False  # Partial validation performed

        return True  # Full validation performed

    full_validation = check_out(out["out"])

    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
    _check_fx_compatible(model, x, eager_out=out)

    if dev == "cuda":
        # torch.autocast(device_type="cuda") is the supported spelling of the
        # deprecated torch.cuda.amp.autocast(); the defaults are the same.
        with torch.autocast(device_type="cuda"), torch.no_grad(), freeze_rng_state():
            out = model(x)
            # See autocast_flaky_numerics comment at top of file.
            if model_name not in autocast_flaky_numerics:
                full_validation &= check_out(out["out"])

    if not full_validation:
        msg = (
            f"The output of {test_segmentation_model.__name__} could only be partially validated. "
            "This is likely due to unit-test flakiness, but you may "
            "want to do additional manual checks if you made "
            "significant changes to the codebase."
        )
        warnings.warn(msg, RuntimeWarning)
        # NOTE(review): this excerpt may be truncated after the warning;
        # confirm against the full file that nothing else follows here.

Callers

nothing calls this directly

Calls 8

set_rng_seed — Function · 0.90
freeze_rng_state — Function · 0.90
check_out — Function · 0.85
_check_jit_scriptable — Function · 0.85
_check_fx_compatible — Function · 0.85
_check_input_backprop — Function · 0.85
get — Method · 0.80
to — Method · 0.80

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…