(ckpt_name, model, force_sync_upload=False)
| 408 | |
| 409 | # function for saving/removing |
def save_model(ckpt_name, model, force_sync_upload=False):
    """Write the ControlNet weights to ``args.output_dir`` and optionally upload them.

    The model's ``state_dict()`` is passed through
    ``model_util.convert_controlnet_state_dict_to_sd`` before saving
    (presumably converting key layout; see that helper to confirm).

    Args:
        ckpt_name: File name of the checkpoint, relative to ``args.output_dir``.
            A ``.safetensors`` extension selects the safetensors writer;
            any other extension is written with ``torch.save``.
        model: Module whose ``state_dict()`` is saved.
        force_sync_upload: Forwarded unchanged to ``huggingface_util.upload``.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_path = os.path.join(args.output_dir, ckpt_name)
    accelerator.print(f"\nsaving checkpoint: {ckpt_path}")

    sd = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())

    # When a save dtype is configured, replace every tensor in place with a
    # detached CPU copy cast to that dtype. The comprehension is fully built
    # before update() runs, so the dict is never mutated while being iterated.
    if save_dtype is not None:
        sd.update({key: tensor.detach().clone().to("cpu").to(save_dtype) for key, tensor in sd.items()})

    _, ext = os.path.splitext(ckpt_path)
    if ext == ".safetensors":
        # Imported lazily so the dependency is only needed for this format.
        from safetensors.torch import save_file

        save_file(sd, ckpt_path)
    else:
        torch.save(sd, ckpt_path)

    # Upload only when a Hugging Face repo has been configured.
    if args.huggingface_repo_id is None:
        return
    huggingface_util.upload(args, ckpt_path, "/" + ckpt_name, force_sync_upload=force_sync_upload)
| 434 | def remove_model(old_ckpt_name): |
| 435 | old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) |
no test coverage detected