hub / github.com/huggingface/diffusers / test_inference

Method test_inference

tests/hooks/test_hooks.py:205–223
(self)

Source from the content-addressed store, hash-verified

203          self.assertTrue(torch.allclose(output1, output2))
204
205      def test_inference(self):
206          registry = HookRegistry.check_if_exists_or_initialize(self.model)
207          registry.register_hook(AddHook(1), "add_hook")
208          registry.register_hook(MultiplyHook(2), "multiply_hook")
209
210          input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
211          output1 = self.model(input).mean().detach().cpu().item()
212
213          registry.remove_hook("multiply_hook")
214          new_input = input * 2
215          output2 = self.model(new_input).mean().detach().cpu().item()
216
217          registry.remove_hook("add_hook")
218          new_input = input * 2 + 1
219          output3 = self.model(new_input).mean().detach().cpu().item()
220
221          self.assertAlmostEqual(output1, output2, places=5)
222          self.assertAlmostEqual(output1, output3, places=5)
223          self.assertAlmostEqual(output2, output3, places=5)
224
225      def test_skip_layer_hook(self):
226          registry = HookRegistry.check_if_exists_or_initialize(self.model)
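
The three assertions in test_inference can only hold if, with both hooks registered, the model effectively sees input * 2 + 1. That means the hook registered last ("multiply_hook") transforms the input first, followed by "add_hook". The sketch below reproduces that arithmetic with toy stand-ins. ToyRegistry, ToyAddHook, ToyMultiplyHook and toy_model are hypothetical names invented for illustration, and the reverse-registration ordering is an assumption inferred from the assertions, not taken from the diffusers HookRegistry source.

```python
# Toy stand-ins for illustration only; this is NOT the diffusers HookRegistry API.
class ToyAddHook:
    def __init__(self, value):
        self.value = value

    def pre_forward(self, x):
        # Shift the input before the model sees it.
        return x + self.value


class ToyMultiplyHook:
    def __init__(self, value):
        self.value = value

    def pre_forward(self, x):
        # Scale the input before the model sees it.
        return x * self.value


class ToyRegistry:
    def __init__(self, model):
        self.model = model
        self.hooks = {}  # dicts preserve registration order

    def register_hook(self, hook, name):
        self.hooks[name] = hook

    def remove_hook(self, name):
        del self.hooks[name]

    def __call__(self, x):
        # Assumption: the most recently registered hook sees the input first.
        for hook in reversed(list(self.hooks.values())):
            x = hook.pre_forward(x)
        return self.model(x)


def toy_model(x):
    # Arbitrary scalar function standing in for self.model.
    return 3.0 * x - 0.5


registry = ToyRegistry(toy_model)
registry.register_hook(ToyAddHook(1), "add_hook")
registry.register_hook(ToyMultiplyHook(2), "multiply_hook")

x = 0.25
output1 = registry(x)          # both hooks: model(x * 2 + 1)

registry.remove_hook("multiply_hook")
output2 = registry(x * 2)      # add hook only: model(x * 2 + 1)

registry.remove_hook("add_hook")
output3 = registry(x * 2 + 1)  # no hooks: model(x * 2 + 1)
```

If the hooks ran in registration order instead, the first call would compute model((x + 1) * 2) and output1 would differ from the other two, so the test also acts as a check on hook ordering and on remove_hook fully undoing a hook's effect.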

Callers

nothing calls this directly

Calls 7

get_generator · Method · 0.95
AddHook · Class · 0.85
MultiplyHook · Class · 0.85
register_hook · Method · 0.80
remove_hook · Method · 0.80
model · Method · 0.45

Tested by

no test coverage detected