hub / github.com/pytorch/vision / disable_weight_loading

Function disable_weight_loading

test/test_models.py:71–112

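When testing models, the two slowest operations are downloading the weights to a file and loading them into the model. Unless you want to test against specific weights, these steps can be disabled without any drawbacks. Including this fixture in the signature of your test, i.e. `test_foo(disable_weight_loading)`, patches every occurrence of `load_state_dict_from_url` and every `load_state_dict` override on `nn.Module` subclasses in `torchvision.models` to be no-ops.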

(mocker)
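To make "patched to be a no-op" concrete, here is a minimal sketch that uses only the standard library. `TinyModel` and `load_pretrained` are hypothetical stand-ins, not torchvision code, and `unittest.mock.patch.object` is used in place of the `mocker.patch` call from pytest-mock that the fixture relies on.

```python
from unittest import mock


class TinyModel:
    """Hypothetical stand-in for an nn.Module subclass."""

    def __init__(self):
        # Stands in for freshly initialised (random) parameters.
        self.weights = [0.0, 0.0]

    def load_state_dict(self, state_dict):
        self.weights = list(state_dict["weights"])


def load_pretrained(model):
    """Hypothetical loader that would normally overwrite the weights."""
    model.load_state_dict({"weights": [1.0, 2.0]})
    return model


# Unpatched: the stored weights replace the initial ones.
loaded = load_pretrained(TinyModel())

# Patched: load_state_dict is swapped for a MagicMock, so the call is
# recorded but the model keeps its initial weights.
with mock.patch.object(TinyModel, "load_state_dict") as fake_load:
    skipped = load_pretrained(TinyModel())
    call_count = fake_load.call_count

print(loaded.weights, skipped.weights, call_count)
```

The patched model still runs, but it never sees the stored weights, which is why the fixture's docstring warns against comparing outputs with reference values.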


# Module-level imports the fixture relies on.
import contextlib
import pkgutil
import sys

import pytest
import torch.nn as nn
from torchvision import models


@pytest.fixture
def disable_weight_loading(mocker):
    """When testing models, the two slowest operations are downloading the weights to a file and loading them
    into the model. Unless you want to test against specific weights, these steps can be disabled without any
    drawbacks.

    Including this fixture in the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
    through all modules in `torchvision.models` and will patch all occurrences of the function
    `load_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
    no-ops.

    .. warning::

        Loaded models are still executable as normal, but will always have random weights. Make sure not to use
        this fixture if you want to compare the model output against reference values.

    """
    starting_point = models
    function_name = "load_state_dict_from_url"
    method_name = "load_state_dict"

    # Names of every module and subpackage below torchvision.models.
    module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
    # Always patch the canonical definitions; module-level references are added below.
    targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
    for name in module_names:
        # Skip modules that have not been imported yet.
        module = sys.modules.get(name)
        if not module:
            continue

        # The module holds its own reference to the download function.
        if function_name in module.__dict__:
            targets.add(f"{module.__name__}.{function_name}")

        # nn.Module subclasses that define load_state_dict themselves.
        targets.update(
            {
                f"{module.__name__}.{obj.__name__}.{method_name}"
                for obj in module.__dict__.values()
                if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
            }
        )

    for target in targets:
        # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
        with contextlib.suppress(AttributeError):
            mocker.patch(target)
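The target-collection loop in the fixture can be tried out without torchvision. The sketch below applies the same pattern to the standard-library `json` package, assuming `json.JSONEncoder` as the base class and `encode` as the method to patch. The helper `collect_method_targets` and the deliberately bad target are illustrative additions, and `unittest.mock.patch` with an `ExitStack` stands in for `mocker.patch`.

```python
import contextlib
import json
import pkgutil
import sys
from unittest import mock


def collect_method_targets(package, base_class, method_name):
    """Return dotted patch targets for subclasses that define method_name themselves."""
    prefix = f"{package.__name__}."
    module_names = {info.name for info in pkgutil.walk_packages(package.__path__, prefix)}
    targets = set()
    for name in module_names:
        module = sys.modules.get(name)  # only modules that are already imported
        if module is None:
            continue
        for obj in vars(module).values():
            if isinstance(obj, type) and issubclass(obj, base_class) and method_name in vars(obj):
                targets.add(f"{module.__name__}.{obj.__name__}.{method_name}")
    return targets


targets = collect_method_targets(json, json.JSONEncoder, "encode")

# A deliberately bad target: the class exists, the attribute does not.
targets.add("json.encoder.JSONEncoder.no_such_method")

patched = []
with contextlib.ExitStack() as stack:
    for target in sorted(targets):
        # Entering the patch raises AttributeError for the missing attribute;
        # suppress() lets the loop carry on with the remaining targets.
        with contextlib.suppress(AttributeError):
            stack.enter_context(mock.patch(target))
            patched.append(target)
    inside = json.JSONEncoder().encode({"a": 1})  # a MagicMock while patched

outside = json.JSONEncoder().encode({"a": 1})  # the real encoder again
print(patched, type(inside).__name__, outside)
```

Checking `method_name in vars(obj)` rather than `hasattr` limits the targets to classes that define the method themselves, which matches the `method_name in obj.__dict__` test in the fixture.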

Callers

nothing calls this directly

Calls 2

getMethod · 0.80
updateMethod · 0.45

Tested by

no test coverage detected
