Repository: github.com/adobe-research/custom-diffusion (branch: main)

Function main

main(args) · src/diffusers_training.py, lines 594–1102
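`main` takes a single `args` namespace. The parser below is a hypothetical minimal sketch that covers only the attributes the excerpt of `main` on this page reads. The flag spellings are assumed to match the attribute names, and every default and example value is an illustrative assumption, not the script's real default.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: only the attributes read by the excerpt of main() shown here.
    # Flag names mirror the attribute names; defaults are illustrative assumptions.
    p = argparse.ArgumentParser(description="Sketch of the arguments main(args) reads")
    p.add_argument("--output_dir", default="outputs")
    p.add_argument("--logging_dir", default="logs")
    p.add_argument("--gradient_accumulation_steps", type=int, default=1)
    p.add_argument("--mixed_precision", choices=["no", "fp16", "bf16"], default=None)
    p.add_argument("--report_to", default="tensorboard")
    p.add_argument("--seed", type=int, default=None)
    # Either a JSON file listing concepts, or the four single-concept flags below.
    p.add_argument("--concepts_list", default=None)
    p.add_argument("--instance_prompt", default=None)
    p.add_argument("--class_prompt", default=None)
    p.add_argument("--instance_data_dir", default=None)
    p.add_argument("--class_data_dir", default=None)
    p.add_argument("--with_prior_preservation", action="store_true")
    p.add_argument("--real_prior", action="store_true")
    p.add_argument("--num_class_images", type=int, default=100)
    return p


# Example values are illustrative only.
args = build_parser().parse_args([
    "--instance_prompt", "photo of a <new1> cat",
    "--class_prompt", "cat",
    "--instance_data_dir", "data/cat",
    "--class_data_dir", "real_reg/samples_cat",
    "--with_prior_preservation", "--real_prior",
    "--num_class_images", "200",
])
print(args.concepts_list is None, args.num_class_images)  # True 200
```

With no `--concepts_list`, `main` falls back to building a single concept from the four per-concept flags.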


```python
# Module-level names this excerpt relies on. They are imported or defined near the
# top of src/diffusers_training.py (not shown on this page); the import lines below
# are a reconstruction, not copied from the file.
import json
import logging
import os
from pathlib import Path

import diffusers
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers.utils import is_wandb_available

import retrieve  # assumed: the retrieval helper in src/retrieve.py

logger = get_logger(__name__)


def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        # Pass the project config built above. The listing passed logging_dir directly
        # (deprecated in accelerate) and left accelerator_project_config unused.
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    # (No such check appears in the lines shown here.)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    # Verbose library logging on the local main process only; errors only elsewhere.
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    # Build the concept list: a single concept from the CLI arguments, or a
    # JSON file holding a list of dicts with the same four keys.
    if args.concepts_list is None:
        args.concepts_list = [
            {
                "instance_prompt": args.instance_prompt,
                "class_prompt": args.class_prompt,
                "instance_data_dir": args.instance_data_dir,
                "class_data_dir": args.class_data_dir
            }
        ]
    else:
        with open(args.concepts_list, "r") as f:
            args.concepts_list = json.load(f)

    if args.with_prior_preservation:
        for i, concept in enumerate(args.concepts_list):
            class_images_dir = Path(concept['class_data_dir'])
            class_images_dir.mkdir(parents=True, exist_ok=True)
            if args.real_prior:
                # Only the main process retrieves real regularization images, and only
                # when fewer than num_class_images are already present in images/.
                images_dir = class_images_dir / 'images'
                if accelerator.is_main_process:
                    if not images_dir.exists() or len(list(images_dir.iterdir())) < args.num_class_images:
                        retrieve.retrieve(concept['class_prompt'], class_images_dir, args.num_class_images)
                concept['class_prompt'] = os.path.join(class_images_dir, 'caption.txt')
                # ... (remainder of main() not shown; the function continues to line 1102)
```
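The condition guarding `retrieve.retrieve()` reads as: fetch real images unless `<class_data_dir>/images` already holds at least `num_class_images` entries. A self-contained sketch of that predicate, using a hypothetical helper name `needs_retrieval`:

```python
import tempfile
from pathlib import Path


def needs_retrieval(class_images_dir: Path, num_class_images: int) -> bool:
    """Hypothetical helper mirroring the condition in main(): retrieve real
    regularization images unless <class_images_dir>/images already contains
    at least num_class_images entries."""
    images_dir = class_images_dir / "images"
    if not images_dir.exists():
        return True
    # iterdir() counts every entry (files and subdirectories), as main() does.
    return sum(1 for _ in images_dir.iterdir()) < num_class_images


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    print(needs_retrieval(root, 2))  # True: no images/ directory yet
    (root / "images").mkdir()
    for k in range(2):
        (root / "images" / f"{k}.jpg").touch()
    print(needs_retrieval(root, 2))  # False: 2 entries already present
    print(needs_retrieval(root, 3))  # True: fewer than 3 present
```

In `main()` this check runs only when `accelerator.is_main_process` is true, so retrieval is not repeated on every process.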

Callers: 1

Calls: 9 (7 listed)

PromptDataset · class · 0.90
collate_fn · function · 0.90
get_full_repo_name · function · 0.85
create_custom_diffusion · function · 0.70
freeze_params · function · 0.70
encode · method · 0.45
save_pretrained · method · 0.45
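`collate_fn` appears among the calls. In DreamBooth-style training with prior preservation, a collate function commonly appends the class (prior) examples after the instance examples so one forward pass covers both, and the batch is later split in half for the two loss terms. The sketch below shows that pattern with plain Python lists instead of tensors; the key names (`instance_prompt_ids`, `class_images`, and so on) and the helper `split_instance_and_prior` are hypothetical, and this is not the repository's `collate_fn`.

```python
from typing import Any, Dict, List, Tuple


def collate_fn(examples: List[Dict[str, Any]], with_prior_preservation: bool) -> Dict[str, list]:
    """Hypothetical prior-preservation collate sketch (plain lists, no torch).

    Instance items come first, then class items, so the batch can later be
    split into two equal halves: [instance part | prior part].
    """
    input_ids = [ex["instance_prompt_ids"] for ex in examples]
    pixel_values = [ex["instance_images"] for ex in examples]
    if with_prior_preservation:
        input_ids += [ex["class_prompt_ids"] for ex in examples]
        pixel_values += [ex["class_images"] for ex in examples]
    return {"input_ids": input_ids, "pixel_values": pixel_values}


def split_instance_and_prior(values: list) -> Tuple[list, list]:
    # Like torch.chunk(x, 2, dim=0) on an even-length batch built by collate_fn.
    half = len(values) // 2
    return values[:half], values[half:]


batch = collate_fn(
    [
        {"instance_prompt_ids": [1, 2], "instance_images": "inst0", "class_prompt_ids": [3], "class_images": "cls0"},
        {"instance_prompt_ids": [4], "instance_images": "inst1", "class_prompt_ids": [5], "class_images": "cls1"},
    ],
    with_prior_preservation=True,
)
print(batch["pixel_values"])  # ['inst0', 'inst1', 'cls0', 'cls1']
print(split_instance_and_prior(batch["pixel_values"]))  # (['inst0', 'inst1'], ['cls0', 'cls1'])
```

Keeping instance items before class items is what makes a simple two-way split recover each group.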

Tested by: no test coverage detected
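No tests are reported for `main`. One hypothetical way to make the concept-list step testable is to factor it into a pure function; the name `load_concepts_list` and the tests below are illustrative assumptions, written with the standard library only.

```python
import json
import os
import tempfile
from types import SimpleNamespace


def load_concepts_list(args) -> list:
    """Hypothetical extraction of the concept-list step in main(): a single
    concept from CLI args, or a list of dicts read from a JSON file."""
    if args.concepts_list is None:
        return [
            {
                "instance_prompt": args.instance_prompt,
                "class_prompt": args.class_prompt,
                "instance_data_dir": args.instance_data_dir,
                "class_data_dir": args.class_data_dir,
            }
        ]
    with open(args.concepts_list, "r") as f:
        return json.load(f)


def test_single_concept_from_args():
    args = SimpleNamespace(concepts_list=None, instance_prompt="photo of a <new1> cat",
                           class_prompt="cat", instance_data_dir="data/cat",
                           class_data_dir="real_reg/samples_cat")
    concepts = load_concepts_list(args)
    assert len(concepts) == 1 and concepts[0]["class_prompt"] == "cat"


def test_concepts_from_json_file():
    entries = [{"instance_prompt": "a", "class_prompt": "b",
                "instance_data_dir": "c", "class_data_dir": "d"}] * 2
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "concepts.json")
        with open(path, "w") as f:
            json.dump(entries, f)
        assert load_concepts_list(SimpleNamespace(concepts_list=path)) == entries


if __name__ == "__main__":
    test_single_concept_from_args()
    test_concepts_from_json_file()
    print("ok")
```

The test functions follow pytest naming, so `pytest` would also collect them. Returning the list instead of assigning to `args.concepts_list` keeps the step free of side effects.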