(args)
| 661 | |
| 662 | |
| 663 | def main(args): |
| 664 | if args.report_to == "wandb" and args.hub_token is not None: |
| 665 | raise ValueError( |
| 666 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." |
| 667 | " Please use `hf auth login` to authenticate with the Hub." |
| 668 | ) |
| 669 | |
| 670 | logging_dir = Path(args.output_dir, args.logging_dir) |
| 671 | |
| 672 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) |
| 673 | |
| 674 | accelerator = Accelerator( |
| 675 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 676 | mixed_precision=args.mixed_precision, |
| 677 | log_with=args.report_to, |
| 678 | project_config=accelerator_project_config, |
| 679 | ) |
| 680 | |
| 681 | # Disable AMP for MPS. |
| 682 | if torch.backends.mps.is_available(): |
| 683 | accelerator.native_amp = False |
| 684 | |
| 685 | if args.report_to == "wandb": |
| 686 | if not is_wandb_available(): |
| 687 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") |
| 688 | import wandb |
| 689 | |
| 690 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate |
| 691 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. |
| 692 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. |
| 693 | # Make one log on every process with the configuration for debugging. |
| 694 | logging.basicConfig( |
| 695 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 696 | datefmt="%m/%d/%Y %H:%M:%S", |
| 697 | level=logging.INFO, |
| 698 | ) |
| 699 | logger.info(accelerator.state, main_process_only=False) |
| 700 | if accelerator.is_local_main_process: |
| 701 | transformers.utils.logging.set_verbosity_warning() |
| 702 | diffusers.utils.logging.set_verbosity_info() |
| 703 | else: |
| 704 | transformers.utils.logging.set_verbosity_error() |
| 705 | diffusers.utils.logging.set_verbosity_error() |
| 706 | |
| 707 | # We need to initialize the trackers we use, and also store our configuration. |
| 708 | # The trackers initializes automatically on the main process. |
| 709 | if accelerator.is_main_process: |
| 710 | accelerator.init_trackers("custom-diffusion", config=vars(args)) |
| 711 | |
| 712 | # If passed along, set the training seed now. |
| 713 | if args.seed is not None: |
| 714 | set_seed(args.seed) |
| 715 | if args.concepts_list is None: |
| 716 | args.concepts_list = [ |
| 717 | { |
| 718 | "instance_prompt": args.instance_prompt, |
| 719 | "class_prompt": args.class_prompt, |
| 720 | "instance_data_dir": args.instance_data_dir, |
no test coverage detected
searching dependent graphs…