(ckpt_name, model, force_sync_upload=False)
| 421 | |
| 422 | # function for saving/removing |
def save_model(ckpt_name, model, force_sync_upload=False):
    """Write ``model``'s weights to ``args.output_dir/ckpt_name`` and optionally upload them.

    The on-disk format is chosen from the file extension: a ``.safetensors``
    path is written with ``safetensors.torch.save_file`` together with the
    SAI model-spec metadata; any other extension is written with
    ``torch.save`` (no metadata is stored in that case).

    Args:
        ckpt_name: Checkpoint file name, relative to ``args.output_dir``.
        model: Object exposing ``state_dict()`` (presumably the ControlNet
            module being trained -- confirm against the caller).
        force_sync_upload: Forwarded to ``huggingface_util.upload``; only
            used when ``args.huggingface_repo_id`` is set.
    """
    out_dir = args.output_dir
    os.makedirs(out_dir, exist_ok=True)
    target_path = os.path.join(out_dir, ckpt_name)

    accelerator.print(f"\nsaving checkpoint: {target_path}")
    metadata = model_io.get_sai_model_spec(None, args, True, True, False)
    metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet"
    weights = model.state_dict()

    # When a save dtype is requested, swap each entry for a detached CPU copy
    # cast to that dtype. Entries are replaced in place (rather than building a
    # new dict) so the container returned by state_dict() is what gets saved.
    if save_dtype is not None:
        for name in list(weights):
            weights[name] = weights[name].detach().clone().to("cpu").to(save_dtype)

    _, extension = os.path.splitext(target_path)
    if extension != ".safetensors":
        torch.save(weights, target_path)
    else:
        # Imported lazily so safetensors is only needed when that format is used.
        from safetensors.torch import save_file

        save_file(weights, target_path, metadata)

    if args.huggingface_repo_id is None:
        return
    huggingface_util.upload(args, target_path, "/" + ckpt_name, force_sync_upload=force_sync_upload)
| 447 | |
| 448 | def remove_model(old_ckpt_name): |
| 449 | old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) |
no test coverage detected