MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / main

Function main

demo.py:122–236  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

120
121
122def main():
123 # parse options
124 cfg = parse_args(phase="demo") # parse config file
125 cfg.FOLDER = cfg.TEST.FOLDER
126
127 # create logger
128 logger = create_logger(cfg, phase="test")
129
130 task = cfg.DEMO.TASK
131 text = None
132
133 output_dir = Path(
134 os.path.join(cfg.FOLDER, str(cfg.model.model_type), str(cfg.NAME),
135 "samples_" + cfg.TIME))
136 output_dir.mkdir(parents=True, exist_ok=True)
137
138 logger.info(OmegaConf.to_yaml(cfg))
139
140 # set seed
141 pl.seed_everything(cfg.SEED_VALUE)
142
143 # gpu setting
144 if cfg.ACCELERATOR == "gpu":
145 os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
146 str(x) for x in cfg.DEVICE)
147 device = torch.device("cuda")
148
149 # Dataset
150 datamodule = build_data(cfg)
151 logger.info("datasets module {} initialized".format("".join(
152 cfg.DATASET.target.split('.')[-2])))
153
154 # create model
155 total_time = time.time()
156 model = build_model(cfg, datamodule)
157 logger.info("model {} loaded".format(cfg.model.target))
158
159 # loading state dict
160 if cfg.TEST.CHECKPOINTS:
161 logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
162 state_dict = torch.load(cfg.TEST.CHECKPOINTS,
163 map_location="cpu")["state_dict"]
164 model.load_state_dict(state_dict)
165 else:
166 logger.warning(
167 "No checkpoints provided, using random initialized model")
168
169 model.to(device)
170
171 if cfg.DEMO.EXAMPLE:
172 # Check txt file input
173 # load txt
174 return_dict = load_example_input(cfg.DEMO.EXAMPLE, task, model)
175 text, in_joints = return_dict['text'], return_dict['motion_joints']
176
177 batch_size = 64
178 if text:
179 for b in tqdm(range(len(text) // batch_size + 1)):

Callers 1

demo.py File · 0.70

Calls 10

parse_args Function · 0.90
create_logger Function · 0.90
build_data Function · 0.90
build_model Function · 0.90
device Method · 0.80
load_state_dict Method · 0.80
detach Method · 0.80
save Method · 0.80
load_example_input Function · 0.70
to Method · 0.45

Tested by

no test coverage detected