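When testing models, the two slowest operations are the downloading of the weights to a file and loading them into the model. Unless you want to test against specific weights, these steps can be disabled without any drawbacks. Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will patch all occurrences of the function `load_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be no-ops.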
(mocker)
| 69 | |
@pytest.fixture
def disable_weight_loading(mocker):
    """Patch away weight downloading and weight loading for the requesting test.

    Downloading the weights to a file and loading them into the model are the two slowest operations when testing
    models. Unless a test needs specific weights, both steps can be disabled without any drawbacks.

    Requesting this fixture, i.e. `test_foo(disable_weight_loading)`, patches the following to be no-ops:

    - `torchvision._internally_replaced_utils.load_state_dict_from_url` and `torch.nn.Module.load_state_dict`,
    - every `load_state_dict_from_url` found in the namespace of a submodule of `models`, and
    - every `load_state_dict` defined directly on an `nn.Module` subclass found in such a namespace.

    Submodules of `models` that are not present in `sys.modules` are skipped.

    .. warning::

        Loaded models are still executable as normal, but will always have random weights. Do not use this fixture
        if you want to compare the model output against reference values.

    """
    root_package = models
    download_fn = "load_state_dict_from_url"
    load_method = "load_state_dict"

    # The two canonical locations are always patched, regardless of what the walk below finds.
    patch_targets = {
        f"torchvision._internally_replaced_utils.{download_fn}",
        f"torch.nn.Module.{load_method}",
    }

    prefix = f"{root_package.__name__}."
    submodule_names = {info.name for info in pkgutil.walk_packages(root_package.__path__, prefix)}

    for qualified_name in submodule_names:
        loaded = sys.modules.get(qualified_name)
        if not loaded:
            continue

        namespace = loaded.__dict__

        # A module that did `from ... import load_state_dict_from_url` holds its own reference, which has to be
        # patched separately from the canonical location.
        if download_fn in namespace:
            patch_targets.add(f"{loaded.__name__}.{download_fn}")

        # Only classes that define the method themselves are targeted; inherited definitions are covered by the
        # class that owns them.
        for candidate in namespace.values():
            if not isinstance(candidate, type):
                continue
            if not issubclass(candidate, nn.Module):
                continue
            if load_method in candidate.__dict__:
                patch_targets.add(f"{loaded.__name__}.{candidate.__name__}.{load_method}")

    for dotted_path in patch_targets:
        # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
        with contextlib.suppress(AttributeError):
            mocker.patch(dotted_path)
| 113 | |
| 114 | |
| 115 | def _get_expected_file(name=None): |