MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / main

Function main

test.py:40–138  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

38
39
40def main():
41 # parse options
42 cfg = parse_args(phase="test") # parse config file
43 cfg.FOLDER = cfg.TEST.FOLDER
44
45 # Logger
46 logger = create_logger(cfg, phase="test")
47 logger.info(OmegaConf.to_yaml(cfg))
48
49 # Output dir
50 model_name = cfg.model.target.split('.')[-2].lower()
51 output_dir = Path(
52 os.path.join(cfg.FOLDER, model_name, cfg.NAME, "samples_" + cfg.TIME))
53 if cfg.TEST.SAVE_PREDICTIONS:
54 output_dir.mkdir(parents=True, exist_ok=True)
55 logger.info(f"Saving predictions to {str(output_dir)}")
56
57 # Seed
58 pl.seed_everything(cfg.SEED_VALUE)
59
60 # Environment Variables
61 os.environ["TOKENIZERS_PARALLELISM"] = "false"
62
63 # Callbacks
64 callbacks = build_callbacks(cfg, logger=logger, phase="test")
65 logger.info("Callbacks initialized")
66
67 # Dataset
68 datamodule = build_data(cfg)
69 logger.info("datasets module {} initialized".format("".join(
70 cfg.DATASET.target.split('.')[-2])))
71
72 # Model
73 model = build_model(cfg, datamodule)
74 logger.info("model {} loaded".format(cfg.model.target))
75
76 # Lightning Trainer
77 trainer = pl.Trainer(
78 benchmark=False,
79 max_epochs=cfg.TRAIN.END_EPOCH,
80 accelerator=cfg.ACCELERATOR,
81 devices=list(range(len(cfg.DEVICE))),
82 default_root_dir=cfg.FOLDER_EXP,
83 reload_dataloaders_every_n_epochs=1,
84 deterministic=False,
85 detect_anomaly=False,
86 enable_progress_bar=True,
87 logger=None,
88 callbacks=callbacks,
89 )
90
91 # Strict load vae model
92 if cfg.TRAIN.PRETRAINED_VAE:
93 load_pretrained_vae(cfg, model, logger)
94
95 # loading state dict
96 if cfg.TEST.CHECKPOINTS:
97 load_pretrained(cfg, model, logger, phase="test")

Callers 1

test.py File · 0.70

Calls 12

parse_args Function · 0.90
create_logger Function · 0.90
build_callbacks Function · 0.90
build_data Function · 0.90
build_model Function · 0.90
load_pretrained_vae Function · 0.90
load_pretrained Function · 0.90
get_metric_statistics Function · 0.85
print_table Function · 0.85
items Method · 0.80
mm_mode Method · 0.45
update Method · 0.45

Tested by

no test coverage detected