MCPcopy
hub / github.com/meta-pytorch/captum / jit_test_assert

Method jit_test_assert

tests/attr/test_jit.py:142–217  ·  view source on GitHub ↗
(self: BaseTest)

Source from the content-addressed store, hash-verified

140 """
141
142 def jit_test_assert(self: BaseTest) -> None:
143 model_1 = model
144 attr_args = args
145 if (
146 mode is JITCompareMode.data_parallel_jit_trace
147 or JITCompareMode.data_parallel_jit_script
148 ):
149 if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
150 raise unittest.SkipTest(
151 "Skipping GPU test since CUDA not available."
152 )
153 # Construct cuda_args, moving all tensor inputs in args to CUDA device
154 cuda_args = {}
155 for key in args:
156 if isinstance(args[key], Tensor):
157 cuda_args[key] = args[key].cuda()
158 elif isinstance(args[key], tuple):
159 cuda_args[key] = tuple(
160 elem.cuda() if isinstance(elem, Tensor) else elem
161 for elem in args[key]
162 )
163 else:
164 cuda_args[key] = args[key]
165 attr_args = cuda_args
166 model_1 = model_1.cuda()
167
168 # Initialize models based on JITCompareMode
169 if (
170 mode is JITCompareMode.cpu_jit_script
171 or JITCompareMode.data_parallel_jit_script
172 ):
173 model_2 = torch.jit.script(model_1) # type: ignore
174 elif (
175 mode is JITCompareMode.cpu_jit_trace
176 or JITCompareMode.data_parallel_jit_trace
177 ):
178 formatted_inputs = _format_tensor_into_tuples(args["inputs"])
179 additional_args: Tuple[Any, ...] = (
180 _format_additional_forward_args(args["additional_forward_args"])
181 if "additional_forward_args" in args
182 and args["additional_forward_args"] is not None
183 else ()
184 )
185 if formatted_inputs is not None:
186 all_inps = formatted_inputs + additional_args
187 else:
188 all_inps = additional_args
189 model_2 = torch.jit.trace(model_1, all_inps) # type: ignore
190 else:
191 raise AssertionError("JIT compare mode type is not valid.")
192
193 attr_method_1 = algorithm(model_1)
194 attr_method_2 = algorithm(model_2)
195
196 if noise_tunnel:
197 attr_method_1 = NoiseTunnel(attr_method_1)
198 attr_method_2 = NoiseTunnel(attr_method_2)
199 if attr_method_1.has_convergence_delta():

Callers

nothing calls this directly

Calls 7

has_convergence_delta · Method · 0.95
attribute · Method · 0.95
NoiseTunnel · Class · 0.90
setUp · Method · 0.45

Tested by

no test coverage detected