(args)
| 592 | |
| 593 | |
| 594 | def main(args): |
| 595 | logging_dir = Path(args.output_dir, args.logging_dir) |
| 596 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) |
| 597 | |
| 598 | accelerator = Accelerator( |
| 599 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 600 | mixed_precision=args.mixed_precision, |
| 601 | log_with=args.report_to, |
| 602 | logging_dir=logging_dir, |
| 603 | ) |
| 604 | |
| 605 | if args.report_to == "wandb": |
| 606 | if not is_wandb_available(): |
| 607 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") |
| 608 | import wandb |
| 609 | |
| 610 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate |
| 611 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. |
| 612 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. |
| 613 | # Make one log on every process with the configuration for debugging. |
| 614 | logging.basicConfig( |
| 615 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 616 | datefmt="%m/%d/%Y %H:%M:%S", |
| 617 | level=logging.INFO, |
| 618 | ) |
| 619 | logger.info(accelerator.state, main_process_only=False) |
| 620 | if accelerator.is_local_main_process: |
| 621 | transformers.utils.logging.set_verbosity_warning() |
| 622 | diffusers.utils.logging.set_verbosity_info() |
| 623 | else: |
| 624 | transformers.utils.logging.set_verbosity_error() |
| 625 | diffusers.utils.logging.set_verbosity_error() |
| 626 | |
| 627 | if args.seed is not None: |
| 628 | set_seed(args.seed) |
| 629 | if args.concepts_list is None: |
| 630 | args.concepts_list = [ |
| 631 | { |
| 632 | "instance_prompt": args.instance_prompt, |
| 633 | "class_prompt": args.class_prompt, |
| 634 | "instance_data_dir": args.instance_data_dir, |
| 635 | "class_data_dir": args.class_data_dir |
| 636 | } |
| 637 | ] |
| 638 | else: |
| 639 | with open(args.concepts_list, "r") as f: |
| 640 | args.concepts_list = json.load(f) |
| 641 | |
| 642 | if args.with_prior_preservation: |
| 643 | for i, concept in enumerate(args.concepts_list): |
| 644 | class_images_dir = Path(concept['class_data_dir']) |
| 645 | if not class_images_dir.exists(): |
| 646 | class_images_dir.mkdir(parents=True, exist_ok=True) |
| 647 | if args.real_prior: |
| 648 | if accelerator.is_main_process: |
| 649 | if not Path(os.path.join(class_images_dir, 'images')).exists() or len(list(Path(os.path.join(class_images_dir, 'images')).iterdir())) < args.num_class_images: |
| 650 | retrieve.retrieve(concept['class_prompt'], class_images_dir, args.num_class_images) |
| 651 | concept['class_prompt'] = os.path.join(class_images_dir, 'caption.txt') |
no test coverage detected