MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / val_vae_forward

Method val_vae_forward

mGPT/models/mgpt.py:321–369  ·  view source on GitHub ↗
(self, batch, split="train")

Source retrieved from the content-addressed store (hash-verified)

319
@torch.no_grad()
def val_vae_forward(self, batch, split="train"):
    """Reconstruct a validation batch through the VAE and gather eval outputs.

    Each sample of ``batch["motion"]`` is cropped to its valid length along
    dim 1, passed through ``self.vae``, and the reconstruction is written into
    a zero-initialised tensor shaped like the reference. Samples whose length
    is 0 are skipped and stay all-zero. Reference and reconstruction are then
    converted with ``self.feats2joints`` and renormalised with
    ``self.datamodule.renorm4t2m``.

    Args:
        batch: Mapping with a ``"motion"`` tensor indexed as
            ``[sample, step, ...]`` and a ``"length"`` sequence (list or
            tensor) holding one valid length per sample.
        split: Not used by this method; accepted for signature compatibility.

    Returns:
        dict with keys ``"m_ref"``, ``"joints_ref"``, ``"m_rst"``,
        ``"joints_rst"`` and ``"length"``. In multimodal mode every sample
        (and its length) appears ``MM_NUM_REPEATS`` times, consecutively.
    """
    feats_ref = batch["motion"]
    lengths = batch["length"]

    # Multimodal evaluation: repeat every sample MM_NUM_REPEATS times.
    # lengths must be expanded in the same interleaved order as feats_ref
    # (a, a, b, b). Multiplying a list would tile it (a, b, a, b) and
    # multiplying a tensor would scale the lengths themselves.
    if self.trainer.datamodule.is_mm:
        num_repeats = self.hparams.cfg.METRIC.MM_NUM_REPEATS
        feats_ref = feats_ref.repeat_interleave(num_repeats, dim=0)
        if torch.is_tensor(lengths):
            lengths = lengths.repeat_interleave(num_repeats, dim=0)
        else:
            lengths = [
                length for length in lengths for _ in range(num_repeats)
            ]

    # Encode & decode each sample on its valid prefix only; positions past
    # the reconstruction (and length-0 samples) remain zero.
    feats_rst = torch.zeros_like(feats_ref)
    for i in range(len(feats_ref)):
        length = lengths[i]
        if length == 0:
            continue
        feats_pred, _, _ = self.vae(feats_ref[i:i + 1, :length])
        feats_rst[i:i + 1, :feats_pred.shape[1], :] = feats_pred

    # Recover joints for evaluation (before renormalisation).
    joints_ref = self.feats2joints(feats_ref)
    joints_rst = self.feats2joints(feats_rst)

    # Renorm for evaluation.
    feats_ref = self.datamodule.renorm4t2m(feats_ref)
    feats_rst = self.datamodule.renorm4t2m(feats_rst)

    return {
        "m_ref": feats_ref,
        "joints_ref": joints_ref,
        "m_rst": feats_rst,
        "joints_rst": joints_rst,
        "length": lengths,
    }
370
371
372 def allsplit_step(self, split: str, batch, batch_idx):

Callers (1)

allsplit_step · Method · 0.95

Calls (3)

encode · Method · 0.80
feats2joints · Method · 0.45
renorm4t2m · Method · 0.45

Tested by

no test coverage detected