(self, batch, task="t2m")
| 63 | self.codeFrequency = torch.zeros((self.hparams.codebook_size, )) |
| 64 | |
def forward(self, batch, task="t2m"):
    """Generate motion with the language model and decode it to features.

    The prompts in ``batch["text"]`` are passed to
    ``self.lm.generate_direct`` (sampling enabled); each generated token
    sequence is decoded with ``self.vae.decode``, the per-sample results
    are zero-padded to a common length, and joints are recovered with
    ``self.feats2joints``.

    Args:
        batch: Mapping that must contain ``"text"`` and ``"length"``.
            Task ``"pred"`` additionally reads ``"motion"`` (concatenated
            in front of the generated tokens before decoding); task
            ``"inbetween"`` additionally reads ``"motion_heading"`` and
            ``"motion_tailing"``.
        task: One of ``"t2m"``, ``"m2t"``, ``"pred"`` or ``"inbetween"``.

    Returns:
        dict with keys
        ``"texts"``  - second value returned by ``generate_direct``;
        ``"feats"``  - zero-padded tensor ``(num_samples, max_len, dim)``;
        ``"joints"`` - ``self.feats2joints(feats)``;
        ``"length"`` - list with the unpadded frame count of every row of
        ``"feats"`` (one entry per sample, for every task).

    Raises:
        NotImplementedError: If ``task`` is not a supported task name.
        ValueError: If the batch contains no samples.
    """
    # Reject unknown tasks before paying for generation.
    if task not in ("t2m", "m2t", "pred", "inbetween"):
        raise NotImplementedError(f"Unsupported task: {task!r}")

    texts = batch["text"]
    lengths_ref = batch["length"]

    # Forward
    outputs, output_texts = self.lm.generate_direct(texts, do_sample=True)

    # Motion Decode
    # NOTE(review): decoded motion is assumed to be (1, frames, dim); the
    # code below relies on shape[1] being the frame axis - confirm.
    feats_rst_lst = []
    lengths = []
    for i in range(len(texts)):
        if task == "pred":
            motion = self.vae.decode(
                torch.cat((batch["motion"][i], outputs[i])))
        else:
            motion = self.vae.decode(outputs[i])

        if task == "inbetween":
            # Keep only the middle half of the generated motion (relative
            # to the reference length) and splice it between the given
            # heading and tailing segments.
            quarter = lengths_ref[i] // 4
            motion = torch.cat(
                (batch["motion_heading"][i][None],
                 motion[:, quarter:quarter * 3, ...],
                 batch["motion_tailing"][i][None]),
                dim=1)

        # Measure the length AFTER any splicing so that the padded buffer
        # is wide enough and "length" matches what is stored in "feats".
        feats_rst_lst.append(motion)
        lengths.append(motion.shape[1])

    if not feats_rst_lst:
        raise ValueError("forward() requires a non-empty batch")

    max_len = max(lengths)
    feats_rst = torch.zeros(
        (len(feats_rst_lst), max_len, motion.shape[-1]),
        device=self.device)

    # padding and concat
    for i, sample in enumerate(feats_rst_lst):
        feats_rst[i, :sample.shape[1], ...] = sample

    # Recover joints for evaluation
    joints_rst = self.feats2joints(feats_rst)

    return {
        "texts": output_texts,
        "feats": feats_rst,
        "joints": joints_rst,
        "length": lengths,
    }
| 122 |
no test coverage detected