(self, batch, split="train")
| 319 | |
@torch.no_grad()
def val_vae_forward(self, batch, split="train"):
    """Reconstruct every motion in ``batch`` with the VAE for validation.

    Each sample is passed through ``self.vae`` on its own, truncated to its
    valid length. The reconstruction is written into a zero-filled tensor of
    the same shape as the input. As a result, frames beyond the
    reconstruction, and samples whose length is 0, stay zero.

    Args:
        batch: mapping with ``"motion"`` and ``"length"``.
            ``"motion"`` is a tensor indexed as ``[sample, frame, feature]``;
            presumably these are padded motion features (TODO confirm).
            ``"length"`` holds per-sample valid frame counts, as a list or
            a 1-D tensor.
        split: not used by this method.

    Returns:
        dict with:
            ``"m_ref"`` / ``"m_rst"``: reference and reconstructed features
                after ``renorm4t2m``.
            ``"joints_ref"`` / ``"joints_rst"``: output of ``feats2joints``,
                computed before renormalization.
            ``"length"``: per-sample lengths, aligned row-for-row with the
                tensors above.
    """
    feats_ref = batch["motion"]
    lengths = batch["length"]

    # Repeat for multimodal evaluation. repeat_interleave duplicates each
    # sample consecutively ([a, a, b, b]), so lengths must be repeated the
    # same way. ``list * n`` would tile them ([la, lb, la, lb]) and
    # ``tensor * n`` would scale the values; both misalign lengths with rows.
    if self.trainer.datamodule.is_mm:
        n_repeats = self.hparams.cfg.METRIC.MM_NUM_REPEATS
        feats_ref = feats_ref.repeat_interleave(n_repeats, dim=0)
        if torch.is_tensor(lengths):
            lengths = lengths.repeat_interleave(n_repeats, dim=0)
        else:
            lengths = [length for length in lengths for _ in range(n_repeats)]

    # Motion encode & decode, one sample at a time so that each one is cut
    # to its own valid length; zero-length samples keep the zero fill.
    feats_rst = torch.zeros_like(feats_ref)
    for i in range(len(feats_ref)):
        if lengths[i] == 0:
            continue
        feats_pred, _, _ = self.vae(feats_ref[i:i + 1, :lengths[i]])
        feats_rst[i:i + 1, :feats_pred.shape[1], :] = feats_pred

    # Recover joints for evaluation (before renormalization below).
    joints_ref = self.feats2joints(feats_ref)
    joints_rst = self.feats2joints(feats_rst)

    # Renorm for evaluation
    feats_ref = self.datamodule.renorm4t2m(feats_ref)
    feats_rst = self.datamodule.renorm4t2m(feats_rst)

    return {
        "m_ref": feats_ref,
        "joints_ref": joints_ref,
        "m_rst": feats_rst,
        "joints_rst": joints_rst,
        "length": lengths,
    }
| 370 | |
| 371 | |
| 372 | def allsplit_step(self, split: str, batch, batch_idx): |
no test coverage detected