()
| 71 | |
| 72 | |
| 73 | def main(): |
| 74 | ######################### |
| 75 | # SETUP Accelerator # |
| 76 | ######################### |
| 77 | config = get_config() |
| 78 | |
| 79 | # Enable TF32 on Ampere GPUs |
| 80 | if config.training.enable_tf32: |
| 81 | torch.backends.cuda.matmul.allow_tf32 = True |
| 82 | torch.backends.cudnn.benchmark = True |
| 83 | torch.backends.cudnn.deterministic = False |
| 84 | |
| 85 | config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") |
| 86 | accelerator = Accelerator( |
| 87 | gradient_accumulation_steps=config.training.gradient_accumulation_steps, |
| 88 | mixed_precision=config.training.mixed_precision, |
| 89 | log_with="wandb", |
| 90 | project_dir=config.experiment.logging_dir, |
| 91 | split_batches=True, |
| 92 | ) |
| 93 | |
| 94 | total_batch_size_per_gpu = (config.training.batch_size_t2i |
| 95 | + config.training.batch_size_lm |
| 96 | + config.training.batch_size_mmu) |
| 97 | total_batch_size = ( |
| 98 | (config.training.batch_size_t2i + config.training.batch_size_lm + config.training.batch_size_mmu) |
| 99 | * accelerator.num_processes * config.training.gradient_accumulation_steps |
| 100 | ) |
| 101 | |
| 102 | if accelerator.distributed_type == DistributedType.DEEPSPEED: |
| 103 | accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( |
| 104 | total_batch_size_per_gpu |
| 105 | ) |
| 106 | |
| 107 | ##################################### |
| 108 | # SETUP LOGGING, SEED and CONFIG # |
| 109 | ##################################### |
| 110 | # Make one log on every process with the configuration for debugging. |
| 111 | logging.basicConfig( |
| 112 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 113 | datefmt="%m/%d/%Y %H:%M:%S", |
| 114 | level=logging.INFO, |
| 115 | ) |
| 116 | logger.info(accelerator.state, main_process_only=False) |
| 117 | if accelerator.is_local_main_process: |
| 118 | set_verbosity_info() |
| 119 | else: |
| 120 | set_verbosity_error() |
| 121 | |
| 122 | # We need to initialize the trackers we use, and also store our configuration. |
| 123 | # The trackers initialize automatically on the main process. |
| 124 | if accelerator.is_main_process: |
| 125 | resume_wandb_run = config.wandb.resume |
| 126 | run_id = config.wandb.get("run_id", None) |
| 127 | if run_id is None: |
| 128 | resume_wandb_run = False |
| 129 | run_id = wandb.util.generate_id() |
| 130 | config.wandb.run_id = run_id |
no test coverage detected