(args)
| 1017 | |
| 1018 | |
| 1019 | def main(args): |
| 1020 | if args.report_to == "wandb" and args.hub_token is not None: |
| 1021 | raise ValueError( |
| 1022 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." |
| 1023 | " Please use `huggingface-cli login` to authenticate with the Hub." |
| 1024 | ) |
| 1025 | |
| 1026 | if torch.backends.mps.is_available() and args.mixed_precision == "bf16": |
| 1027 | # due to pytorch#99272, MPS does not yet support bfloat16. |
| 1028 | raise ValueError( |
| 1029 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." |
| 1030 | ) |
| 1031 | |
| 1032 | logging_dir = Path(args.output_dir, args.logging_dir) |
| 1033 | |
| 1034 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) |
| 1035 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) |
| 1036 | init_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) |
| 1037 | accelerator = Accelerator( |
| 1038 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 1039 | mixed_precision=args.mixed_precision, |
| 1040 | log_with=args.report_to, |
| 1041 | project_config=accelerator_project_config, |
| 1042 | kwargs_handlers=[ddp_kwargs, init_kwargs], |
| 1043 | ) |
| 1044 | |
| 1045 | # Disable AMP for MPS. |
| 1046 | if torch.backends.mps.is_available(): |
| 1047 | accelerator.native_amp = False |
| 1048 | |
| 1049 | if args.report_to == "wandb": |
| 1050 | if not is_wandb_available(): |
| 1051 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") |
| 1052 | |
| 1053 | # Make one log on every process with the configuration for debugging. |
| 1054 | logging.basicConfig( |
| 1055 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 1056 | datefmt="%m/%d/%Y %H:%M:%S", |
| 1057 | level=logging.INFO, |
| 1058 | ) |
| 1059 | logger.info(accelerator.state, main_process_only=False) |
| 1060 | if accelerator.is_local_main_process: |
| 1061 | transformers.utils.logging.set_verbosity_warning() |
| 1062 | diffusers.utils.logging.set_verbosity_info() |
| 1063 | else: |
| 1064 | transformers.utils.logging.set_verbosity_error() |
| 1065 | diffusers.utils.logging.set_verbosity_error() |
| 1066 | |
| 1067 | # If passed along, set the training seed now. |
| 1068 | if args.seed is not None: |
| 1069 | set_seed(args.seed) |
| 1070 | |
| 1071 | # Handle the repository creation |
| 1072 | if accelerator.is_main_process: |
| 1073 | if args.output_dir is not None: |
| 1074 | os.makedirs(args.output_dir, exist_ok=True) |
| 1075 | |
| 1076 | if args.push_to_hub: |
no test coverage detected