()
| 54 | |
| 55 | |
| 56 | def main(): |
| 57 | ######################### |
| 58 | # SETUP Accelerator # |
| 59 | ######################### |
| 60 | config = get_config() |
| 61 | |
| 62 | # Enable TF32 on Ampere GPUs |
| 63 | if config.training.enable_tf32: |
| 64 | torch.backends.cuda.matmul.allow_tf32 = True |
| 65 | torch.backends.cudnn.benchmark = True |
| 66 | torch.backends.cudnn.deterministic = False |
| 67 | |
| 68 | config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") |
| 69 | accelerator = Accelerator( |
| 70 | gradient_accumulation_steps=config.training.gradient_accumulation_steps, |
| 71 | mixed_precision=config.training.mixed_precision, |
| 72 | log_with="wandb", |
| 73 | project_dir=config.experiment.logging_dir, |
| 74 | split_batches=True, |
| 75 | ) |
| 76 | |
| 77 | bs_mixed_modal = config.training.batch_size_mixed_modal |
| 78 | |
| 79 | if "concat" in config.dataset.mixed_loader_mode: |
| 80 | raise NotImplementedError |
| 81 | else: |
| 82 | total_batch_size_per_gpu = bs_mixed_modal * config.dataset.accumulation |
| 83 | total_batch_size_without_accum = total_batch_size_per_gpu * accelerator.num_processes |
| 84 | total_batch_size = total_batch_size_without_accum * config.training.gradient_accumulation_steps |
| 85 | |
| 86 | if accelerator.distributed_type == DistributedType.DEEPSPEED: |
| 87 | accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( |
| 88 | total_batch_size_per_gpu |
| 89 | ) |
| 90 | |
| 91 | ##################################### |
| 92 | # SETUP LOGGING, SEED and CONFIG # |
| 93 | ##################################### |
| 94 | # Make one log on every process with the configuration for debugging. |
| 95 | logging.basicConfig( |
| 96 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 97 | datefmt="%m/%d/%Y %H:%M:%S", |
| 98 | level=logging.INFO, |
| 99 | ) |
| 100 | logger.info(accelerator.state, main_process_only=False) |
| 101 | if accelerator.is_local_main_process: |
| 102 | set_verbosity_info() |
| 103 | else: |
| 104 | set_verbosity_error() |
| 105 | |
| 106 | # We need to initialize the trackers we use, and also store our configuration. |
| 107 | # The trackers initializes automatically on the main process. |
| 108 | if accelerator.is_main_process: |
| 109 | resume_wandb_run = config.wandb.resume |
| 110 | run_id = config.wandb.get("run_id", None) |
| 111 | if run_id is None: |
| 112 | resume_wandb_run = False |
| 113 | run_id = wandb.util.generate_id() |
no test coverage detected