(self)
| 165 | |
@skip_mps
def test_serialization(self):
    """Check that an EMA model survives a save_pretrained/from_pretrained round trip.

    The EMA wrapper is written to a temporary directory, reloaded as a
    ``UNet2DConditionModel`` and moved to the source UNet's device. No EMA step
    has been taken, so both models must produce the same sample (atol=1e-4).
    """
    source_unet, ema_wrapper = self.get_models()
    latents, steps, hidden_states = self.get_dummy_inputs()

    with tempfile.TemporaryDirectory() as save_dir:
        ema_wrapper.save_pretrained(save_dir)
        restored_unet = UNet2DConditionModel.from_pretrained(
            save_dir, model_cls=UNet2DConditionModel
        ).to(source_unet.device)

    # Zero EMA updates have been applied, so the reloaded model should agree
    # with the original UNet on the same inputs.
    expected_sample = source_unet(latents, steps, hidden_states).sample
    restored_sample = restored_unet(latents, steps, hidden_states).sample

    assert torch.allclose(expected_sample, restored_sample, atol=1e-4)
| 181 | |
| 182 | |
| 183 | class EMAModelTestsForeach(unittest.TestCase): |
nothing calls this directly
no test coverage detected