MCPcopy
hub / github.com/kohya-ss/sd-scripts / save_model

Function save_model

sdxl_train_control_net.py:423–446  ·  view source on GitHub ↗
(ckpt_name, model, force_sync_upload=False)

Source from the content-addressed store, hash-verified

421
422 # function for saving/removing
def save_model(ckpt_name, model, force_sync_upload=False):
    """Save the ControlNet weights under ``args.output_dir`` and optionally upload them.

    This is a closure: it reads ``args``, ``accelerator``, ``save_dtype``,
    ``model_io``, ``sai_model_spec``, ``huggingface_util``, ``os`` and ``torch``
    from the enclosing training scope.

    Args:
        ckpt_name: File name, relative to ``args.output_dir``. An extension of
            exactly ``.safetensors`` selects safetensors output (with model-spec
            metadata). Any other extension is written with ``torch.save``, and
            no metadata is saved.
        model: Object whose ``state_dict()`` is serialized.
        force_sync_upload: Passed through to ``huggingface_util.upload``. It is
            only used when ``args.huggingface_repo_id`` is set.
    """
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    target_path = os.path.join(output_dir, ckpt_name)

    accelerator.print(f"\nsaving checkpoint: {target_path}")

    # Model-spec metadata, with the architecture tag overridden to mark this
    # file as an SDXL ControlNet.
    metadata = model_io.get_sai_model_spec(None, args, True, True, False)
    controlnet_arch = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet"
    metadata["modelspec.architecture"] = controlnet_arch

    weights = model.state_dict()

    if save_dtype is not None:
        # Replace entries in place rather than building a new dict. This keeps
        # the mapping object returned by state_dict(), including its type and
        # attributes, and swaps each tensor for a detached CPU copy cast to
        # save_dtype. The model's own parameters are not modified.
        for name in list(weights):
            weights[name] = weights[name].detach().clone().to("cpu").to(save_dtype)

    _, extension = os.path.splitext(target_path)
    if extension != ".safetensors":
        torch.save(weights, target_path)
    else:
        from safetensors.torch import save_file

        save_file(weights, target_path, metadata)

    if args.huggingface_repo_id is None:
        return
    huggingface_util.upload(args, target_path, "/" + ckpt_name, force_sync_upload=force_sync_upload)
447
448 def remove_model(old_ckpt_name):
449 old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)

Callers 1

train — Function · 0.70

Calls 4

to — Method · 0.80
get_sai_model_spec — Method · 0.45
state_dict — Method · 0.45
keys — Method · 0.45

Tested by

no test coverage detected