(self: BaseTest)
| 140 | """ |
| 141 | |
def jit_test_assert(self: BaseTest) -> None:
    """Build an eager model and a TorchScript copy, then wrap both in ``algorithm``.

    Reads ``model``, ``args``, ``mode``, ``algorithm`` and ``noise_tunnel``
    from the enclosing scope. They are not defined in this function;
    presumably they are bound by the surrounding test-generator closure.

    In the data-parallel modes, ``model`` and every tensor in ``args`` are
    moved to CUDA first. This includes tensors found one level inside a tuple
    value. ``model_2`` is then created from ``model_1`` with either
    ``torch.jit.script`` or ``torch.jit.trace``, depending on ``mode``.
    Each model is wrapped in ``algorithm``, and additionally in
    ``NoiseTunnel`` when ``noise_tunnel`` is truthy. The remainder of the
    function presumably compares the two attribution methods.

    Raises:
        unittest.SkipTest: a data-parallel mode was requested but CUDA is
            not available.
        AssertionError: ``mode`` matches none of the four JIT compare modes.
    """
    model_1 = model
    attr_args = args
    # Compare ``mode`` against each member explicitly. Writing
    # ``mode is A or JITCompareMode.B`` would evaluate the bare member ``B``
    # on its own, which is truthy for an Enum member. That would send every
    # mode down this CUDA branch.
    if (
        mode is JITCompareMode.data_parallel_jit_trace
        or mode is JITCompareMode.data_parallel_jit_script
    ):
        if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
            raise unittest.SkipTest(
                "Skipping GPU test since CUDA not available."
            )
        # Construct cuda_args, moving all tensor inputs in args to CUDA device.
        # Tensors directly in args and tensors inside tuple values are moved;
        # everything else is passed through unchanged.
        cuda_args = {}
        for key in args:
            value = args[key]
            if isinstance(value, Tensor):
                cuda_args[key] = value.cuda()
            elif isinstance(value, tuple):
                cuda_args[key] = tuple(
                    elem.cuda() if isinstance(elem, Tensor) else elem
                    for elem in value
                )
            else:
                cuda_args[key] = value
        attr_args = cuda_args
        model_1 = model_1.cuda()

    # Initialize models based on JITCompareMode (explicit comparisons, as above).
    if (
        mode is JITCompareMode.cpu_jit_script
        or mode is JITCompareMode.data_parallel_jit_script
    ):
        model_2 = torch.jit.script(model_1)  # type: ignore
    elif (
        mode is JITCompareMode.cpu_jit_trace
        or mode is JITCompareMode.data_parallel_jit_trace
    ):
        # Trace with ``attr_args`` rather than ``args`` so the example inputs
        # live on the same device as ``model_1`` (CUDA in data-parallel mode).
        formatted_inputs = _format_tensor_into_tuples(attr_args["inputs"])
        additional_args: Tuple[Any, ...] = (
            _format_additional_forward_args(attr_args["additional_forward_args"])
            if "additional_forward_args" in attr_args
            and attr_args["additional_forward_args"] is not None
            else ()
        )
        if formatted_inputs is not None:
            all_inps = formatted_inputs + additional_args
        else:
            all_inps = additional_args
        model_2 = torch.jit.trace(model_1, all_inps)  # type: ignore
    else:
        raise AssertionError("JIT compare mode type is not valid.")

    attr_method_1 = algorithm(model_1)
    attr_method_2 = algorithm(model_2)

    if noise_tunnel:
        attr_method_1 = NoiseTunnel(attr_method_1)
        attr_method_2 = NoiseTunnel(attr_method_2)
| 199 | if attr_method_1.has_convergence_delta(): |
nothing calls this directly
no test coverage detected