()
| 526 | |
| 527 | |
| 528 | def main(): |
| 529 | args = parse_args() |
| 530 | |
| 531 | if args.report_to == "wandb" and args.hub_token is not None: |
| 532 | raise ValueError( |
| 533 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." |
| 534 | " Please use `hf auth login` to authenticate with the Hub." |
| 535 | ) |
| 536 | |
| 537 | if args.non_ema_revision is not None: |
| 538 | deprecate( |
| 539 | "non_ema_revision!=None", |
| 540 | "0.15.0", |
| 541 | message=( |
| 542 | "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" |
| 543 | " use `--variant=non_ema` instead." |
| 544 | ), |
| 545 | ) |
| 546 | logging_dir = os.path.join(args.output_dir, args.logging_dir) |
| 547 | |
| 548 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) |
| 549 | |
| 550 | accelerator = Accelerator( |
| 551 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 552 | mixed_precision=args.mixed_precision, |
| 553 | log_with=args.report_to, |
| 554 | project_config=accelerator_project_config, |
| 555 | ) |
| 556 | |
| 557 | # Disable AMP for MPS. |
| 558 | if torch.backends.mps.is_available(): |
| 559 | accelerator.native_amp = False |
| 560 | |
| 561 | # Make one log on every process with the configuration for debugging. |
| 562 | logging.basicConfig( |
| 563 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 564 | datefmt="%m/%d/%Y %H:%M:%S", |
| 565 | level=logging.INFO, |
| 566 | ) |
| 567 | logger.info(accelerator.state, main_process_only=False) |
| 568 | if accelerator.is_local_main_process: |
| 569 | datasets.utils.logging.set_verbosity_warning() |
| 570 | transformers.utils.logging.set_verbosity_warning() |
| 571 | diffusers.utils.logging.set_verbosity_info() |
| 572 | else: |
| 573 | datasets.utils.logging.set_verbosity_error() |
| 574 | transformers.utils.logging.set_verbosity_error() |
| 575 | diffusers.utils.logging.set_verbosity_error() |
| 576 | |
| 577 | # If passed along, set the training seed now. |
| 578 | if args.seed is not None: |
| 579 | set_seed(args.seed) |
| 580 | |
| 581 | # Handle the repository creation |
| 582 | if accelerator.is_main_process: |
| 583 | if args.output_dir is not None: |
| 584 | os.makedirs(args.output_dir, exist_ok=True) |
| 585 |
no test coverage detected
searching dependent graphs…