(self, outputs)
| 142 | self.metrics = BaseMetrics(datamodule=self.datamodule, **self.hparams) |
| 143 | |
| 144 | def save_npy(self, outputs): |
| 145 | cfg = self.hparams.cfg |
| 146 | output_dir = Path( |
| 147 | os.path.join( |
| 148 | cfg.FOLDER, |
| 149 | str(cfg.model.target.split('.')[-2].lower()), |
| 150 | str(cfg.NAME), |
| 151 | "samples_" + cfg.TIME, |
| 152 | )) |
| 153 | if cfg.TEST.SAVE_PREDICTIONS: |
| 154 | lengths = [i[1] for i in outputs] |
| 155 | outputs = [i[0] for i in outputs] |
| 156 | |
| 157 | if cfg.TEST.DATASETS[0].lower() in ["humanml3d", "kit"]: |
| 158 | keyids = self.trainer.datamodule.test_dataset.name_list |
| 159 | for i in range(len(outputs)): |
| 160 | for bid in range( |
| 161 | min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])): |
| 162 | keyid = keyids[i * cfg.TEST.BATCH_SIZE + bid] |
| 163 | data = self.trainer.datamodule.test_dataset.data_dict[ |
| 164 | keyid] |
| 165 | |
| 166 | motion = torch.tensor(data['motion'], |
| 167 | device=outputs[i].device) |
| 168 | motion = self.datamodule.normalize(motion) |
| 169 | length = data['length'] |
| 170 | text_list = data['text'] |
| 171 | gen_joints = outputs[i][bid][:lengths[i][bid]].cpu( |
| 172 | ).numpy() |
| 173 | if cfg.TEST.REPLICATION_TIMES > 1: |
| 174 | name = f"{keyid}.npy" |
| 175 | else: |
| 176 | name = f"{keyid}.npy" |
| 177 | # save predictions results |
| 178 | npypath = output_dir / name |
| 179 | np.save(npypath, gen_joints) |
| 180 | npypath = output_dir / f"{keyid}_gt.npy" |
| 181 | joints = self.feats2joints(motion).cpu().numpy() |
| 182 | np.save(npypath, joints) |
| 183 | |
| 184 | with open(output_dir / f"{keyid}.txt", "a") as f: |
| 185 | for text in text_list: |
| 186 | f.write(f"{text['caption']}\n") |
| 187 | |
| 188 | elif cfg.TEST.DATASETS[0].lower() in ["humanact12", "uestc"]: |
| 189 | keyids = range(len(self.trainer.datamodule.test_dataset)) |
| 190 | for i in range(len(outputs)): |
| 191 | for bid in range( |
| 192 | min(cfg.TEST.BATCH_SIZE, outputs[i].shape[0])): |
| 193 | keyid = keyids[i * cfg.TEST.BATCH_SIZE + bid] |
| 194 | gen_joints = outputs[i][bid].cpu() |
| 195 | gen_joints = gen_joints.permute(2, 0, |
| 196 | 1)[:lengths[i][bid], |
| 197 | ...].numpy() |
| 198 | if cfg.TEST.REPLICATION_TIMES > 1: |
| 199 | name = f"{keyid}_{self.rep_i}" |
| 200 | else: |
| 201 | name = f"{keyid}.npy" |
no test coverage detected