Trains a new model.
(args)
| 34 | ################################################################################# |
| 35 | |
| 36 | def main(args): |
| 37 | """ |
| 38 | Trains a new model. |
| 39 | """ |
| 40 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." |
| 41 | |
| 42 | # Setup DDP: |
| 43 | init_distributed_mode(args) |
| 44 | assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size." |
| 45 | rank = dist.get_rank() |
| 46 | device = rank % torch.cuda.device_count() |
| 47 | seed = args.global_seed * dist.get_world_size() + rank |
| 48 | torch.manual_seed(seed) |
| 49 | torch.cuda.set_device(device) |
| 50 | |
| 51 | # Setup an experiment folder: |
| 52 | if rank == 0: |
| 53 | os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders) |
| 54 | experiment_index = len(glob(f"{args.results_dir}/*")) |
| 55 | model_string_name = args.vq_model.replace("/", "-") |
| 56 | experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder |
| 57 | checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints |
| 58 | os.makedirs(checkpoint_dir, exist_ok=True) |
| 59 | logger = create_logger(experiment_dir) |
| 60 | logger.info(f"Experiment directory created at {experiment_dir}") |
| 61 | |
| 62 | time_record = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) |
| 63 | cloud_results_dir = f"{args.cloud_save_path}/{time_record}" |
| 64 | cloud_checkpoint_dir = f"{cloud_results_dir}/{experiment_index:03d}-{model_string_name}/checkpoints" |
| 65 | os.makedirs(cloud_checkpoint_dir, exist_ok=True) |
| 66 | logger.info(f"Experiment directory created in cloud at {cloud_checkpoint_dir}") |
| 67 | |
| 68 | else: |
| 69 | logger = create_logger(None) |
| 70 | |
| 71 | # training args |
| 72 | logger.info(f"{args}") |
| 73 | |
| 74 | # training env |
| 75 | logger.info(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") |
| 76 | |
| 77 | # create and load model |
| 78 | vq_model = VQ_models[args.vq_model]( |
| 79 | codebook_size=args.codebook_size, |
| 80 | codebook_embed_dim=args.codebook_embed_dim, |
| 81 | commit_loss_beta=args.commit_loss_beta, |
| 82 | entropy_loss_ratio=args.entropy_loss_ratio, |
| 83 | dropout_p=args.dropout_p, |
| 84 | ) |
| 85 | logger.info(f"VQ Model Parameters: {sum(p.numel() for p in vq_model.parameters()):,}") |
| 86 | if args.ema: |
| 87 | ema = deepcopy(vq_model).to(device) # Create an EMA of the model for use after training |
| 88 | requires_grad(ema, False) |
| 89 | logger.info(f"VQ Model EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}") |
| 90 | vq_model = vq_model.to(device) |
| 91 | |
| 92 | vq_loss = VQLoss( |
| 93 | disc_start=args.disc_start, |
no test coverage detected