hub / github.com/huggingface/diffusers / test_serialization

Method test_serialization

tests/others/test_ema.py:322–335
(self)

Source from the content-addressed store, hash-verified

@skip_mps
def test_serialization(self):
    unet, ema_unet = self.get_models()
    noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()

    with tempfile.TemporaryDirectory() as tmpdir:
        ema_unet.save_pretrained(tmpdir)
        loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)
        loaded_unet = loaded_unet.to(unet.device)

    # Since no EMA step has been performed the outputs should match.
    output = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample

    assert torch.allclose(output, output_loaded, atol=1e-4)
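The test relies on one property: before any EMA step runs, the averaged weights are still an exact copy of the live weights, so a save/load round trip should reproduce the live model's output. The sketch below illustrates that property with plain Python floats. It is not the diffusers implementation: `TinyEMA`, `forward`, `allclose`, and the `ema.json` file name are hypothetical names invented for illustration, and the fixed-decay update is a simplifying assumption.

```python
import json
import os
import tempfile


class TinyEMA:
    """Hypothetical stand-in for an EMA wrapper; not the diffusers EMAModel API."""

    def __init__(self, params, decay=0.999):
        # Shadow weights start as an exact copy of the live weights.
        self.shadow = list(params)
        self.decay = decay

    def step(self, params):
        # Fixed-decay EMA update: shadow <- decay * shadow + (1 - decay) * param.
        self.shadow = [
            self.decay * s + (1.0 - self.decay) * p
            for s, p in zip(self.shadow, params)
        ]

    def save(self, directory):
        # Hypothetical on-disk format: a single JSON file.
        with open(os.path.join(directory, "ema.json"), "w") as f:
            json.dump({"decay": self.decay, "shadow": self.shadow}, f)

    @classmethod
    def load(cls, directory):
        with open(os.path.join(directory, "ema.json")) as f:
            state = json.load(f)
        return cls(state["shadow"], decay=state["decay"])


def forward(weights, inputs):
    # Toy "model": a dot product of weights and inputs.
    return sum(w * x for w, x in zip(weights, inputs))


def allclose(a, b, atol=1e-4, rtol=1e-5):
    # Tolerance rule documented for torch.allclose: |a - b| <= atol + rtol * |b|.
    return abs(a - b) <= atol + rtol * abs(b)


weights = [0.5, -1.25, 2.0]
inputs = [1.0, 2.0, 3.0]
ema = TinyEMA(weights)

with tempfile.TemporaryDirectory() as tmpdir:
    ema.save(tmpdir)
    loaded = TinyEMA.load(tmpdir)

# No EMA step yet, so the reloaded shadow weights reproduce the live output.
matches_before_step = allclose(forward(weights, inputs), forward(loaded.shadow, inputs))

# Once the live weights move and one EMA step runs, the two outputs diverge.
weights = [w + 1.0 for w in weights]
ema.step(weights)
matches_after_step = allclose(forward(weights, inputs), forward(ema.shadow, inputs))
```

In this sketch `matches_before_step` is true and `matches_after_step` is false, which is why the original test can compare against the plain `unet` output only because it never performs an EMA step.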

Callers

nothing calls this directly

Calls 6

get_models · Method · 0.95
get_dummy_inputs · Method · 0.95
unet · Function · 0.85
save_pretrained · Method · 0.45
from_pretrained · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected