MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / allsplit_step

Method allsplit_step

mGPT/models/mgpt.py:372–494  ·  view source on GitHub ↗
(self, split: str, batch, batch_idx)

Source from the content-addressed store, hash-verified

370
371
372 def allsplit_step(self, split: str, batch, batch_idx):
373 # Compute the losses
374 loss = None
375
376 if self.hparams.stage == "vae" and split in ["train", "val"]:
377 rs_set = self.train_vae_forward(batch)
378 loss = self._losses['losses_' + split].update(rs_set)
379 elif self.hparams.stage in ["lm_instruct", "lm_pretrain"
380 ] and split in ["train"]:
381 rs_set = self.train_lm_forward(batch)
382 loss = self._losses['losses_' + split].update(rs_set)
383 elif self.hparams.stage == 'lm_rl' and split in ['train']:
384 rs_set = self.train_rl_forward(batch)
385 loss = None
386
387 # Compute the metrics
388 if split in ["val", "test"]:
389 if self.hparams.stage == "vae":
390 rs_set = self.val_vae_forward(batch, split)
391 elif self.hparams.stage in ["lm_instruct", "lm_pretrain", "lm_rl"]:
392 if self.hparams.task == "t2m":
393 rs_set = self.val_t2m_forward(batch)
394 elif self.hparams.task == "m2t":
395 rs_set = self.val_m2t_forward(batch)
396 elif self.hparams.task in ["m2m", "pred", "inbetween"]:
397 rs_set = self.val_m2m_forward(batch, self.hparams.task)
398
399 if self.hparams.task not in ["m2t"]:
400 # MultiModality evaluation sperately
401 if self.trainer.datamodule.is_mm:
402 metrics_dicts = ['MMMetrics']
403 else:
404 metrics_dicts = self.hparams.metrics_dict
405
406 if self.hparams.task not in ['pred', 'inbetween'] and 'PredMetrics' in metrics_dicts:
407 metrics_dicts.remove('PredMetrics')
408
409 for metric in metrics_dicts:
410 lengths = batch['length']
411 if metric == "TemosMetric":
412 getattr(self.metrics,
413 metric).update(rs_set["joints_rst"],
414 rs_set["joints_ref"], lengths)
415 elif metric == "TM2TMetrics":
416 if self.hparams.stage in [
417 "lm_instruct", "lm_pretrain", "lm_rl"
418 ]:
419 word_embs = batch['word_embs']
420 pos_ohot = batch['pos_ohot']
421 text_lengths = batch['text_len']
422 if self.trainer.datamodule.is_mm:
423 word_embs = word_embs.repeat_interleave(
424 self.hparams.cfg.METRIC.MM_NUM_REPEATS,
425 dim=0)
426 pos_ohot = pos_ohot.repeat_interleave(
427 self.hparams.cfg.METRIC.MM_NUM_REPEATS,
428 dim=0)
429 text_lengths = text_lengths.repeat_interleave(

Callers 3

training_step (method) · 0.80
validation_step (method) · 0.80
test_step (method) · 0.80

Calls 7

train_vae_forward (method) · 0.95
train_lm_forward (method) · 0.95
val_vae_forward (method) · 0.95
val_t2m_forward (method) · 0.95
val_m2t_forward (method) · 0.95
val_m2m_forward (method) · 0.95
update (method) · 0.45

Tested by 1

test_step (method) · 0.64