github.com/huggingface/diffusers

Function load_model_hook

examples/text_to_image/train_text_to_image.py:672–693
(models, input_dir)

                    weights.pop()  # last line of save_model_hook, shown for context

        # Nested inside main(): args, ema_unet and accelerator come from the enclosing scope.
        def load_model_hook(models, input_dir):
            if args.use_ema:
                # Restore the EMA copy of the UNet saved under <input_dir>/unet_ema.
                load_model = EMAModel.from_pretrained(
                    os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
                )
                ema_unet.load_state_dict(load_model.state_dict())
                if args.offload_ema:
                    # --offload_ema: keep the EMA weights on the CPU in pinned memory.
                    ema_unet.pin_memory()
                else:
                    ema_unet.to(accelerator.device)
                del load_model

            for _ in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)
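The loop body loads a fresh UNet2DConditionModel from the checkpoint and copies its config and weights into the model object that Accelerate already manages, instead of replacing that object. Presumably this keeps references held elsewhere (for example by the optimizer) valid. The plain-Python sketch below illustrates that in-place pattern under stated assumptions: TinyModel and load_into_live_model are hypothetical names, a list stands in for tensor storage, and nothing here is diffusers or torch code.

```python
# Hypothetical stand-ins: TinyModel is NOT a diffusers or torch class.
class TinyModel:
    def __init__(self, **config):
        self.config = dict(config)
        # One "parameter" stored as a mutable list, like a tensor's storage.
        self.params = {"weight": [0.0, 0.0]}

    def register_to_config(self, **kwargs):
        # Merge checkpoint config keys into this instance's config.
        self.config.update(kwargs)

    def state_dict(self):
        # Return copies, so callers do not alias this model's storage.
        return {name: list(values) for name, values in self.params.items()}

    def load_state_dict(self, state):
        # Overwrite values inside the existing lists (slice assignment),
        # so the list objects, and any references to them, survive.
        for name, values in state.items():
            self.params[name][:] = values


def load_into_live_model(model, checkpoint):
    # Same two steps as the loop body: copy config, then copy weights.
    model.register_to_config(**checkpoint.config)
    model.load_state_dict(checkpoint.state_dict())


live = TinyModel(sample_size=64)
optimizer_ref = live.params["weight"]  # stands in for an optimizer's reference

checkpoint = TinyModel(sample_size=64, in_channels=4)
checkpoint.params["weight"] = [1.5, -2.0]

load_into_live_model(live, checkpoint)
print(optimizer_ref)  # [1.5, -2.0]: the held reference sees the restored weights
```

In PyTorch, `load_state_dict` likewise copies values into the module's existing parameters, which is the property this toy model imitates.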

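Callers

Nothing calls this directly. It is registered with accelerator.register_load_state_pre_hook, and Accelerate invokes it as hook(models, input_dir) when accelerator.load_state() restores a checkpoint.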
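Accelerate passes each registered load pre-hook the list of models it manages plus the checkpoint directory, then runs its own loading for whatever is still in that list; that is why load_model_hook pops every model. The sketch below models that contract in a simplified form. MiniAccelerator, toy_load_model_hook and the "checkpoint-500" directory are hypothetical, and strings stand in for models; this is an illustration, not Accelerate's implementation.

```python
from typing import Callable, List

Hook = Callable[[List[str], str], None]


class MiniAccelerator:
    """Hypothetical stand-in for Accelerate's load-hook flow, not its real code."""

    def __init__(self, models: List[str]) -> None:
        self.models = models
        self.pre_hooks: List[Hook] = []
        self.log: List[str] = []

    def register_load_state_pre_hook(self, hook: Hook) -> None:
        self.pre_hooks.append(hook)

    def load_state(self, input_dir: str) -> List[str]:
        # Hooks receive a fresh list of the managed models and may mutate it.
        remaining = list(self.models)
        for hook in self.pre_hooks:
            hook(remaining, input_dir)
        # Default loading only touches models the hooks left in the list.
        for name in remaining:
            self.log.append(f"default-load {name} from {input_dir}")
        return remaining


acc = MiniAccelerator(["unet"])


def toy_load_model_hook(models: List[str], input_dir: str) -> None:
    # Pop every model, as load_model_hook does, so it is not loaded again.
    for _ in range(len(models)):
        model = models.pop()
        acc.log.append(f"hook-load {model} from {input_dir}/unet")


acc.register_load_state_pre_hook(toy_load_model_hook)
left_over = acc.load_state("checkpoint-500")
print(left_over, acc.log)
# [] ['hook-load unet from checkpoint-500/unet']
```

If the hook skipped the pop, the log would also contain a "default-load" entry, which is the double loading the source comment warns about.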

Calls (7)

pin_memory · method · 0.80
register_to_config · method · 0.80
from_pretrained · method · 0.45
load_state_dict · method · 0.45
state_dict · method · 0.45
to · method · 0.45
pop · method · 0.45

Tested by

no test coverage detected