Repository: github.com/kohya-ss/sd-scripts

Function save_model

train_control_net.py, lines 410–432
(ckpt_name, model, force_sync_upload=False)


```python
# function for saving/removing
def save_model(ckpt_name, model, force_sync_upload=False):
    # args, accelerator and save_dtype are not parameters: they are free
    # variables resolved from the enclosing train() scope.
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    accelerator.print(f"\nsaving checkpoint: {ckpt_file}")

    # Convert the ControlNet state dict to SD checkpoint key naming.
    state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())

    # Optionally move every tensor to the CPU and cast it to save_dtype.
    if save_dtype is not None:
        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Pick the serializer from the file extension.
    if os.path.splitext(ckpt_file)[1] == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, ckpt_file)
    else:
        torch.save(state_dict, ckpt_file)

    # Optionally upload the written file to the configured Hugging Face repo.
    if args.huggingface_repo_id is not None:
        huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)

def remove_model(old_ckpt_name):
    old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
    # ... (rest of remove_model is outside this view)
```
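The save path in save_model (an optional dtype cast, then a serializer chosen by file extension) can be exercised outside the training script with a small standalone function. This is a sketch under stated assumptions, not code from the repository: `save_checkpoint` is a hypothetical name, it receives the state dict and dtype as arguments instead of reading `args` and `save_dtype` from the enclosing scope, and it leaves out the ControlNet key conversion and the Hugging Face upload.

```python
import os
import tempfile

import torch


def save_checkpoint(state_dict, ckpt_file, save_dtype=None):
    """Hypothetical standalone sketch of save_model's save path.

    Key conversion and the Hugging Face upload are intentionally omitted.
    """
    # Create the parent directory (save_model creates args.output_dir).
    os.makedirs(os.path.dirname(ckpt_file) or ".", exist_ok=True)

    if save_dtype is not None:
        # Build a new dict so the caller's dict is not modified; the
        # detach/clone/cpu/cast chain mirrors the loop in save_model.
        state_dict = {
            key: v.detach().clone().to("cpu").to(save_dtype)
            for key, v in state_dict.items()
        }

    if os.path.splitext(ckpt_file)[1] == ".safetensors":
        # Lazy import, as in save_model: safetensors is only needed
        # when that format is actually requested.
        from safetensors.torch import save_file

        save_file(state_dict, ckpt_file)
    else:
        torch.save(state_dict, ckpt_file)
    return ckpt_file


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        weights = {"w": torch.ones(2, 3)}
        path = save_checkpoint(weights, os.path.join(tmp, "out", "ckpt.pt"), torch.float16)
        loaded = torch.load(path)
        print(loaded["w"].dtype)  # torch.float16
```

Keeping the safetensors import inside the branch means a run that only writes `.pt`/`.ckpt` files never needs that package installed, which is the same trade-off the original function makes.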

Callers (2)
- train · method · 0.70
- train · function · 0.70

Calls (3)
- to · method · 0.80
- state_dict · method · 0.45
- keys · method · 0.45

Tested by
- No test coverage detected.