github.com/huggingface/diffusers (branch: main)

Function main(args)

examples/custom_diffusion/train_custom_diffusion.py:663–1377


```python
def main(args):
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `hf auth login` to authenticate with the Hub."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("custom-diffusion", config=vars(args))

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
    if args.concepts_list is None:
        args.concepts_list = [
            {
                "instance_prompt": args.instance_prompt,
                "class_prompt": args.class_prompt,
                "instance_data_dir": args.instance_data_dir,
```
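The opening of `main` does some argument handling that can be exercised without a GPU or the training libraries. It refuses `--report_to=wandb` combined with `--hub_token`, it chooses library log verbosity per process, and it wraps the single-concept flags into a one-element `concepts_list` when no list is given. The sketch below reproduces that logic on a plain `argparse.Namespace`, assuming only the standard library. The helper names `validate_reporting`, `normalize_concepts`, and `library_log_levels` are hypothetical. The per-process verbosity is represented with standard `logging` level constants rather than the real `transformers`/`diffusers` `set_verbosity_*` calls. The dict keys are limited to the three visible in the excerpt, because the listing is cut off before the rest of the dict and before any branch that handles a user-supplied `concepts_list`.

```python
import argparse
import logging


def validate_reporting(args: argparse.Namespace) -> None:
    """Hypothetical helper mirroring the first check in main():
    refuse to log to Weights & Biases while a Hub token is passed on the CLI."""
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a "
            "security risk of exposing your token."
        )


def normalize_concepts(args: argparse.Namespace) -> list:
    """Hypothetical helper: when no concepts list is given, wrap the
    single-concept flags into a one-element list of dicts. Only the keys
    visible in the excerpt are used; the real script may add more.
    A non-None concepts_list is returned unchanged here (assumption)."""
    if args.concepts_list is None:
        args.concepts_list = [
            {
                "instance_prompt": args.instance_prompt,
                "class_prompt": args.class_prompt,
                "instance_data_dir": args.instance_data_dir,
            }
        ]
    return args.concepts_list


def library_log_levels(is_local_main_process: bool) -> dict:
    """Verbosity chosen per process in main(), expressed as stdlib levels:
    the local main process keeps transformers at WARNING and diffusers at
    INFO; every other process reports errors only."""
    if is_local_main_process:
        return {"transformers": logging.WARNING, "diffusers": logging.INFO}
    return {"transformers": logging.ERROR, "diffusers": logging.ERROR}


if __name__ == "__main__":
    # Illustrative values only; the prompt and directory are made up.
    args = argparse.Namespace(
        report_to="tensorboard",
        hub_token=None,
        concepts_list=None,
        instance_prompt="photo of a <new1> cat",
        class_prompt="cat",
        instance_data_dir="./data/cat",
    )
    validate_reporting(args)  # no error: wandb is not requested
    print(normalize_concepts(args))
    print(library_log_levels(is_local_main_process=False))
```

Normalizing to a list up front lets the rest of a script like this iterate over `args.concepts_list` the same way whether one concept or several were requested.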

Callers: 1

Calls: 15

is_wandb_available (function, 0.90)
set_seed (function, 0.90)
is_xformers_available (function, 0.90)
AttnProcsLayers (class, 0.90)
get_scheduler (function, 0.90)
text_encoder (function, 0.85)
unet (function, 0.85)
save_new_embed (function, 0.85)
info (method, 0.80)
exists (method, 0.80)
save (method, 0.80)

Tested by: no test coverage detected
