MCP · copy
hub / github.com/OpenMotionLab/MotionGPT / save_npy

Method save_npy

mGPT/models/base.py:144–204  ·  view source on GitHub ↗
(self, outputs)

Source retrieved from the content-addressed store (hash-verified).

142 self.metrics = BaseMetrics(datamodule=self.datamodule, **self.hparams)
143
144 def save_npy(self, outputs):
145 cfg = self.hparams.cfg
146 output_dir = Path(
147 os.path.join(
148 cfg.FOLDER,
149 str(cfg.model.target.split('.')[-2].lower()),
150 str(cfg.NAME),
151 "samples_" + cfg.TIME,
152 ))
153 if cfg.TEST.SAVE_PREDICTIONS:
154 lengths = [i[1] for i in outputs]
155 outputs = [i[0] for i in outputs]
156
157 if cfg.TEST.DATASETS[0].lower() in ["humanml3d", "kit"]:
158 keyids = self.trainer.datamodule.test_dataset.name_list
159 for i in range(len(outputs)):
160 for bid in range(
161 min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])):
162 keyid = keyids[i * cfg.TEST.BATCH_SIZE + bid]
163 data = self.trainer.datamodule.test_dataset.data_dict[
164 keyid]
165
166 motion = torch.tensor(data['motion'],
167 device=outputs[i].device)
168 motion = self.datamodule.normalize(motion)
169 length = data['length']
170 text_list = data['text']
171 gen_joints = outputs[i][bid][:lengths[i][bid]].cpu(
172 ).numpy()
173 if cfg.TEST.REPLICATION_TIMES > 1:
174 name = f"{keyid}.npy"
175 else:
176 name = f"{keyid}.npy"
177 # save predictions results
178 npypath = output_dir / name
179 np.save(npypath, gen_joints)
180 npypath = output_dir / f"{keyid}_gt.npy"
181 joints = self.feats2joints(motion).cpu().numpy()
182 np.save(npypath, joints)
183
184 with open(output_dir / f"{keyid}.txt", "a") as f:
185 for text in text_list:
186 f.write(f"{text['caption']}\n")
187
188 elif cfg.TEST.DATASETS[0].lower() in ["humanact12", "uestc"]:
189 keyids = range(len(self.trainer.datamodule.test_dataset))
190 for i in range(len(outputs)):
191 for bid in range(
192 min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])):
193 keyid = keyids[i * cfg.TEST.BATCH_SIZE + bid]
194 gen_joints = outputs[i][bid].cpu()
195 gen_joints = gen_joints.permute(2, 0,
196 1)[:lengths[i][bid],
197 ...].numpy()
198 if cfg.TEST.REPLICATION_TIMES > 1:
199 name = f"{keyid}_{self.rep_i}"
200 else:
201 name = f"{keyid}.npy"

Callers 1

on_test_epoch_end · Method · 0.95

Calls 3

save · Method · 0.80
normalize · Method · 0.45
feats2joints · Method · 0.45

Tested by

no test coverage detected