(models, input_dir)
| 670 | weights.pop() |
| 671 | |
def load_model_hook(models, input_dir):
    """Restore UNet (and optional EMA) weights; registered as an Accelerate load-state pre-hook.

    Relies on ``args``, ``ema_unet`` and ``accelerator`` from the enclosing scope.

    Args:
        models: Models handed to the hook. Exactly as many entries as it held on
            entry are popped off (in place), so they are not loaded again.
        input_dir: Checkpoint directory. The UNet is read from its ``unet``
            subfolder and, when ``args.use_ema`` is set, the EMA copy from
            ``unet_ema``.
    """
    if args.use_ema:
        ema_dir = os.path.join(input_dir, "unet_ema")
        restored_ema = EMAModel.from_pretrained(ema_dir, UNet2DConditionModel, foreach=args.foreach_ema)
        ema_unet.load_state_dict(restored_ema.state_dict())
        # With EMA offloading the weights are pinned; otherwise they go to the training device.
        if not args.offload_ema:
            ema_unet.to(accelerator.device)
        else:
            ema_unet.pin_memory()
        del restored_ema

    # Count is fixed up front, matching a range(len(models)) loop.
    pending = len(models)
    while pending > 0:
        pending -= 1
        # Pop the model so that it is not loaded again.
        target = models.pop()

        # Read the diffusers-format UNet, then copy its config and weights into the live model.
        checkpoint_unet = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
        target.register_to_config(**checkpoint_unet.config)
        target.load_state_dict(checkpoint_unet.state_dict())
        del checkpoint_unet
# Attach save_model_hook / load_model_hook as Accelerate pre-hooks for save_state / load_state.
for register_hook, hook in (
    (accelerator.register_save_state_pre_hook, save_model_hook),
    (accelerator.register_load_state_pre_hook, load_model_hook),
):
    register_hook(hook)
nothing calls this directly
no test coverage detected
searching dependent graphs…