(self, batch)
| 136 | |
| 137 | @torch.no_grad() |
| 138 | def val_t2m_forward(self, batch): |
| 139 | feats_ref = batch["motion"] |
| 140 | texts = batch["text"] |
| 141 | lengths = batch["length"] |
| 142 | tasks = None |
| 143 | if self.trainer.datamodule.is_mm: |
| 144 | texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS |
| 145 | feats_ref = feats_ref.repeat_interleave( |
| 146 | self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0) |
| 147 | lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS |
| 148 | instructions = pjoin(self.datamodule.hparams.data_root, |
| 149 | 'template_instructions.json') |
| 150 | instructions = json.load(open(instructions, 'r')) |
| 151 | tasks = [instructions["Text-to-Motion"]["caption"]] * len(texts) |
| 152 | |
| 153 | if self.hparams.condition == 'caption': |
| 154 | tasks = [{ |
| 155 | 'input': ['<Caption_Placeholder>'], |
| 156 | 'output': [''] |
| 157 | }] * len(texts) |
| 158 | |
| 159 | if self.hparams.cfg.DATASET.TASK_PATH: |
| 160 | instructions = pjoin(self.hparams.cfg.DATASET.TASK_PATH) |
| 161 | instructions = json.load(open(instructions, 'r')) |
| 162 | tasks = [instructions["Text-to-Motion"]["t2m"]] * len(texts) |
| 163 | |
| 164 | min_len = lengths.copy() |
| 165 | # Forward |
| 166 | outputs = self.lm.generate_conditional(texts, |
| 167 | lengths=lengths, |
| 168 | stage='test', |
| 169 | tasks=tasks) |
| 170 | |
| 171 | # Motion Decode |
| 172 | feats_rst = torch.zeros_like(feats_ref) |
| 173 | |
| 174 | for i in range(len(texts)): |
| 175 | outputs[i] = torch.clamp(outputs[i], |
| 176 | 0, |
| 177 | self.hparams.codebook_size - 1, |
| 178 | out=None) |
| 179 | |
| 180 | if len(outputs[i]) > 1: |
| 181 | motion = self.vae.decode(outputs[i]) |
| 182 | else: |
| 183 | motion = torch.zeros_like(feats_ref[i:i + 1, ...]) |
| 184 | |
| 185 | min_len[i] = min(motion.shape[1], lengths[i]) |
| 186 | |
| 187 | # Cut Motion |
| 188 | feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :lengths[i]] |
| 189 | |
| 190 | # Recover joints for evaluation |
| 191 | joints_ref = self.feats2joints(feats_ref) |
| 192 | joints_rst = self.feats2joints(feats_rst) |
| 193 | |
| 194 | # Renorm for evaluation |
| 195 | feats_ref = self.datamodule.renorm4t2m(feats_ref) |
no test coverage detected