(self)
| 320 | |
@skip_mps
def test_serialization(self):
    """Check that the EMA model survives a save_pretrained/from_pretrained round trip.

    The EMA wrapper is saved to a temporary directory and reloaded as a
    ``UNet2DConditionModel`` on the same device as ``unet``. No EMA step has
    been taken, so the reloaded model must produce the same output as the
    original ``unet`` on the same dummy inputs (within ``atol=1e-4``).
    """
    unet, ema_unet = self.get_models()
    noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()

    with tempfile.TemporaryDirectory() as tmpdir:
        ema_unet.save_pretrained(tmpdir)
        loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)
        loaded_unet = loaded_unet.to(unet.device)

    # Since no EMA step has been performed the outputs should match.
    # Inference only: disable autograd so neither forward pass records a
    # gradient graph that the test never uses.
    with torch.no_grad():
        output = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # The message expression is only evaluated if the assertion fails.
    assert torch.allclose(output, output_loaded, atol=1e-4), (
        f"max abs diff: {(output - output_loaded).abs().max().item()}"
    )
nothing calls this directly
no test coverage detected