(args)
| 164 | |
| 165 | |
| 166 | def main(args): |
| 167 | seed = args.global_seed |
| 168 | torch.manual_seed(seed) |
| 169 | |
| 170 | # Determine if the current process is the main process (rank 0) |
| 171 | is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0) |
| 172 | # Setup an experiment folder and logger only if main process |
| 173 | if is_main_process: |
| 174 | experiment_dir, checkpoint_dir = create_experiment_directory(args) |
| 175 | logger = create_logger(experiment_dir) |
| 176 | OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml")) |
| 177 | logger.info(f"Experiment directory created at {experiment_dir}") |
| 178 | else: |
| 179 | experiment_dir = os.getenv("EXPERIMENT_DIR", "default_path") |
| 180 | checkpoint_dir = os.getenv("CHECKPOINT_DIR", "default_path") |
| 181 | logger = logging.getLogger(__name__) |
| 182 | logger.addHandler(logging.NullHandler()) |
| 183 | tb_logger = TensorBoardLogger(experiment_dir, name="latte") |
| 184 | |
| 185 | # Create the dataset and dataloader |
| 186 | dataset = get_dataset(args) |
| 187 | loader = DataLoader( |
| 188 | dataset, |
| 189 | batch_size=args.local_batch_size, |
| 190 | shuffle=True, |
| 191 | num_workers=args.num_workers, |
| 192 | pin_memory=True, |
| 193 | drop_last=True |
| 194 | ) |
| 195 | if is_main_process: |
| 196 | logger.info(f"Dataset contains {len(dataset)} videos ({args.data_path})") |
| 197 | |
| 198 | sample_size = args.image_size // 8 |
| 199 | args.latent_size = sample_size |
| 200 | |
| 201 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. |
| 202 | num_update_steps_per_epoch = math.ceil(len(loader)) |
| 203 | # Afterwards we recalculate our number of training epochs |
| 204 | num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
| 205 | # In multi GPUs mode, the real batchsize is local_batch_size * GPU numbers |
| 206 | if is_main_process: |
| 207 | logger.info(f"One epoch iteration {num_update_steps_per_epoch} steps") |
| 208 | logger.info(f"Num train epochs: {num_train_epochs}") |
| 209 | |
| 210 | # Initialize the training module |
| 211 | pl_module = LatteTrainingModule(args, logger) |
| 212 | |
| 213 | checkpoint_callback = ModelCheckpoint( |
| 214 | dirpath=checkpoint_dir, |
| 215 | filename="{epoch}-{step}-{train_loss:.2f}-{gradient_norm:.2f}", |
| 216 | save_top_k=-1, |
| 217 | every_n_train_steps=args.ckpt_every, |
| 218 | save_on_train_epoch_end=True, # Optional |
| 219 | ) |
| 220 | |
| 221 | # Trainer |
| 222 | trainer = Trainer( |
| 223 | accelerator="gpu", |
no test coverage detected