hub / github.com/huggingface/diffusers / test_serialization

Method test_serialization

tests/others/test_ema.py:167–180
(self)

Source from the content-addressed store, hash-verified

    @skip_mps
    def test_serialization(self):
        unet, ema_unet = self.get_models()
        noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()

        with tempfile.TemporaryDirectory() as tmpdir:
            ema_unet.save_pretrained(tmpdir)
            loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)
            loaded_unet = loaded_unet.to(unet.device)

        # Since no EMA step has been performed the outputs should match.
        output = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample

        assert torch.allclose(output, output_loaded, atol=1e-4)


class EMAModelTestsForeach(unittest.TestCase):
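
The comment in the test depends on one property: before any EMA step, the averaged weights are an exact copy of the live weights, so a save/load round trip should reproduce the live model's output. The following is a minimal sketch of that property in plain Python. The `TinyEMA` class, the `forward` function, and the `roundtrip_outputs` helper are hypothetical names made up for illustration and are not part of diffusers; the real test works on a UNet and `save_pretrained` / `from_pretrained`.

```python
import json
import os
import tempfile


class TinyEMA:
    """Hypothetical stand-in for an EMA wrapper (not the diffusers API)."""

    def __init__(self, params, decay=0.999):
        # The shadow weights start as an exact copy of the live weights.
        self.shadow = list(params)
        self.decay = decay

    def step(self, params):
        # Standard EMA update: shadow <- decay * shadow + (1 - decay) * param
        self.shadow = [
            self.decay * s + (1.0 - self.decay) * p
            for s, p in zip(self.shadow, params)
        ]

    def save(self, directory):
        # Assumed on-disk format for this sketch: a single JSON file.
        with open(os.path.join(directory, "ema.json"), "w") as f:
            json.dump({"decay": self.decay, "shadow": self.shadow}, f)

    @classmethod
    def load(cls, directory):
        with open(os.path.join(directory, "ema.json")) as f:
            state = json.load(f)
        return cls(state["shadow"], decay=state["decay"])


def forward(weights, inputs):
    # Toy "model": a dot product between weights and inputs.
    return sum(w * x for w, x in zip(weights, inputs))


def roundtrip_outputs(weights, inputs, steps=0):
    """Return (live output, output from the reloaded EMA weights)."""
    ema = TinyEMA(weights)
    live = list(weights)
    for _ in range(steps):
        live = [w + 1.0 for w in live]  # pretend optimizer update
        ema.step(live)
    with tempfile.TemporaryDirectory() as tmpdir:
        ema.save(tmpdir)
        loaded = TinyEMA.load(tmpdir)
    return forward(live, inputs), forward(loaded.shadow, inputs)
```

With `steps=0` the two outputs agree, which is the situation the test asserts with `torch.allclose`. Once at least one step has run, the shadow weights lag behind the live weights and the outputs would generally differ.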

Callers

nothing calls this directly

Calls 6

get_models · Method · 0.95
get_dummy_inputs · Method · 0.95
unet · Function · 0.85
save_pretrained · Method · 0.45
from_pretrained · Method · 0.45
to · Method · 0.45

Tested by

no test coverage detected