MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / main

Function main

train.py:13–90  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

11from mGPT.utils.load_checkpoint import load_pretrained, load_pretrained_vae
12
13def main():
14 # Configs
15 cfg = parse_args(phase="train") # parse config file
16
17 # Logger
18 logger = create_logger(cfg, phase="train") # create logger
19 logger.info(OmegaConf.to_yaml(cfg)) # print config file
20
21 # Seed
22 pl.seed_everything(cfg.SEED_VALUE)
23
24 # Environment Variables
25 os.environ["TOKENIZERS_PARALLELISM"] = "false"
26
27 # Metric Logger
28 pl_loggers = []
29 for loggerName in cfg.LOGGER.TYPE:
30 if loggerName == 'tenosrboard' or cfg.LOGGER.WANDB.params.project:
31 pl_logger = instantiate_from_config(
32 eval(f'cfg.LOGGER.{loggerName.upper()}'))
33 pl_loggers.append(pl_logger)
34
35 # Callbacks
36 callbacks = build_callbacks(cfg, logger=logger, phase='train')
37 logger.info("Callbacks initialized")
38
39 # Dataset
40 datamodule = build_data(cfg)
41 logger.info("datasets module {} initialized".format("".join(
42 cfg.DATASET.target.split('.')[-2])))
43
44 # Model
45 model = build_model(cfg, datamodule)
46 logger.info("model {} loaded".format(cfg.model.target))
47
48 # Lightning Trainer
49 trainer = pl.Trainer(
50 default_root_dir=cfg.FOLDER_EXP,
51 max_epochs=cfg.TRAIN.END_EPOCH,
52 # precision='16',
53 logger=pl_loggers,
54 callbacks=callbacks,
55 check_val_every_n_epoch=cfg.LOGGER.VAL_EVERY_STEPS,
56 accelerator=cfg.ACCELERATOR,
57 devices=cfg.DEVICE,
58 num_nodes=cfg.NUM_NODES,
59 strategy="ddp_find_unused_parameters_true"
60 if len(cfg.DEVICE) > 1 else 'auto',
61 benchmark=False,
62 deterministic=False,
63 )
64 logger.info("Trainer initialized")
65
66 # Strict load pretrained model
67 if cfg.TRAIN.PRETRAINED:
68 load_pretrained(cfg, model, logger)
69
70 # Strict load vae model

Callers 1

train.py — File · 0.70

Calls 8

parse_args — Function · 0.90
create_logger — Function · 0.90
instantiate_from_config — Function · 0.90
build_callbacks — Function · 0.90
build_data — Function · 0.90
build_model — Function · 0.90
load_pretrained — Function · 0.90
load_pretrained_vae — Function · 0.90

Tested by

no test coverage detected