| 370 | |
| 371 | |
| 372 | def allsplit_step(self, split: str, batch, batch_idx): |
| 373 | # Compute the losses |
| 374 | loss = None |
| 375 | |
| 376 | if self.hparams.stage == "vae" and split in ["train", "val"]: |
| 377 | rs_set = self.train_vae_forward(batch) |
| 378 | loss = self._losses['losses_' + split].update(rs_set) |
| 379 | elif self.hparams.stage in ["lm_instruct", "lm_pretrain" |
| 380 | ] and split in ["train"]: |
| 381 | rs_set = self.train_lm_forward(batch) |
| 382 | loss = self._losses['losses_' + split].update(rs_set) |
| 383 | elif self.hparams.stage == 'lm_rl' and split in ['train']: |
| 384 | rs_set = self.train_rl_forward(batch) |
| 385 | loss = None |
| 386 | |
| 387 | # Compute the metrics |
| 388 | if split in ["val", "test"]: |
| 389 | if self.hparams.stage == "vae": |
| 390 | rs_set = self.val_vae_forward(batch, split) |
| 391 | elif self.hparams.stage in ["lm_instruct", "lm_pretrain", "lm_rl"]: |
| 392 | if self.hparams.task == "t2m": |
| 393 | rs_set = self.val_t2m_forward(batch) |
| 394 | elif self.hparams.task == "m2t": |
| 395 | rs_set = self.val_m2t_forward(batch) |
| 396 | elif self.hparams.task in ["m2m", "pred", "inbetween"]: |
| 397 | rs_set = self.val_m2m_forward(batch, self.hparams.task) |
| 398 | |
| 399 | if self.hparams.task not in ["m2t"]: |
| 400 | # MultiModality evaluation sperately |
| 401 | if self.trainer.datamodule.is_mm: |
| 402 | metrics_dicts = ['MMMetrics'] |
| 403 | else: |
| 404 | metrics_dicts = self.hparams.metrics_dict |
| 405 | |
| 406 | if self.hparams.task not in ['pred', 'inbetween'] and 'PredMetrics' in metrics_dicts: |
| 407 | metrics_dicts.remove('PredMetrics') |
| 408 | |
| 409 | for metric in metrics_dicts: |
| 410 | lengths = batch['length'] |
| 411 | if metric == "TemosMetric": |
| 412 | getattr(self.metrics, |
| 413 | metric).update(rs_set["joints_rst"], |
| 414 | rs_set["joints_ref"], lengths) |
| 415 | elif metric == "TM2TMetrics": |
| 416 | if self.hparams.stage in [ |
| 417 | "lm_instruct", "lm_pretrain", "lm_rl" |
| 418 | ]: |
| 419 | word_embs = batch['word_embs'] |
| 420 | pos_ohot = batch['pos_ohot'] |
| 421 | text_lengths = batch['text_len'] |
| 422 | if self.trainer.datamodule.is_mm: |
| 423 | word_embs = word_embs.repeat_interleave( |
| 424 | self.hparams.cfg.METRIC.MM_NUM_REPEATS, |
| 425 | dim=0) |
| 426 | pos_ohot = pos_ohot.repeat_interleave( |
| 427 | self.hparams.cfg.METRIC.MM_NUM_REPEATS, |
| 428 | dim=0) |
| 429 | text_lengths = text_lengths.repeat_interleave( |