(self)
| 203 | self.assertTrue(torch.allclose(output1, output2)) |
| 204 | |
def test_inference(self):
    """Check that registered hooks are equivalent to transforming the input by hand.

    With ``AddHook(1)`` and ``MultiplyHook(2)`` registered, the model output is
    compared against:
      * the model with only the add hook left, run on ``inputs * 2``;
      * the model with no hooks left, run on ``inputs * 2 + 1``.
    All three mean outputs must agree to 5 decimal places.

    NOTE(review): these equivalences imply that the hooks transform the forward
    input and that the multiply hook takes effect before the add hook. Confirm
    this against the HookRegistry implementation.
    """
    registry = HookRegistry.check_if_exists_or_initialize(self.model)
    registry.register_hook(AddHook(1), "add_hook")
    registry.register_hook(MultiplyHook(2), "multiply_hook")

    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
    output1 = self.model(inputs).mean().detach().cpu().item()

    # Once the multiply hook is removed, doubling the input by hand should
    # reproduce output1.
    registry.remove_hook("multiply_hook")
    new_inputs = inputs * 2
    output2 = self.model(new_inputs).mean().detach().cpu().item()

    # Once no hooks are left, apply both transformations by hand.
    registry.remove_hook("add_hook")
    new_inputs = inputs * 2 + 1
    output3 = self.model(new_inputs).mean().detach().cpu().item()

    self.assertAlmostEqual(output1, output2, places=5)
    self.assertAlmostEqual(output1, output3, places=5)
    self.assertAlmostEqual(output2, output3, places=5)
| 224 | |
| 225 | def test_skip_layer_hook(self): |
| 226 | registry = HookRegistry.check_if_exists_or_initialize(self.model) |
nothing calls this directly
no test coverage detected