MCPcopy
hub / github.com/Vchitect/Latte / main

Function main

train_pl.py:166–238  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

164
165
166def main(args):
167 seed = args.global_seed
168 torch.manual_seed(seed)
169
170 # Determine if the current process is the main process (rank 0)
171 is_main_process = (int(os.environ.get("LOCAL_RANK", 0)) == 0)
172 # Setup an experiment folder and logger only if main process
173 if is_main_process:
174 experiment_dir, checkpoint_dir = create_experiment_directory(args)
175 logger = create_logger(experiment_dir)
176 OmegaConf.save(args, os.path.join(experiment_dir, "config.yaml"))
177 logger.info(f"Experiment directory created at {experiment_dir}")
178 else:
179 experiment_dir = os.getenv("EXPERIMENT_DIR", "default_path")
180 checkpoint_dir = os.getenv("CHECKPOINT_DIR", "default_path")
181 logger = logging.getLogger(__name__)
182 logger.addHandler(logging.NullHandler())
183 tb_logger = TensorBoardLogger(experiment_dir, name="latte")
184
185 # Create the dataset and dataloader
186 dataset = get_dataset(args)
187 loader = DataLoader(
188 dataset,
189 batch_size=args.local_batch_size,
190 shuffle=True,
191 num_workers=args.num_workers,
192 pin_memory=True,
193 drop_last=True
194 )
195 if is_main_process:
196 logger.info(f"Dataset contains {len(dataset)} videos ({args.data_path})")
197
198 sample_size = args.image_size // 8
199 args.latent_size = sample_size
200
201 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
202 num_update_steps_per_epoch = math.ceil(len(loader))
203 # Afterwards we recalculate our number of training epochs
204 num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
205 # In multi GPUs mode, the real batchsize is local_batch_size * GPU numbers
206 if is_main_process:
207 logger.info(f"One epoch iteration {num_update_steps_per_epoch} steps")
208 logger.info(f"Num train epochs: {num_train_epochs}")
209
210 # Initialize the training module
211 pl_module = LatteTrainingModule(args, logger)
212
213 checkpoint_callback = ModelCheckpoint(
214 dirpath=checkpoint_dir,
215 filename="{epoch}-{step}-{train_loss:.2f}-{gradient_norm:.2f}",
216 save_top_k=-1,
217 every_n_train_steps=args.ckpt_every,
218 save_on_train_epoch_end=True, # Optional
219 )
220
221 # Trainer
222 trainer = Trainer(
223 accelerator="gpu",

Callers 1

train_pl.py — File · 0.70

Calls 6

get_dataset — Function · 0.90
cleanup — Function · 0.90
save — Method · 0.80
create_logger — Function · 0.70
LatteTrainingModule — Class · 0.70

Tested by

no test coverage detected