(self, batch)
| 300 | return rs_set |
| 301 | |
def train_vae_forward(self, batch):
    """Run the VAE on a batch's motion features and collect the results.

    Reads ``batch["motion"]``, passes it through ``self.vae`` and converts
    both the input features and the VAE output to joints with
    ``self.feats2joints``.

    Args:
        batch: Mapping that provides the motion features under the
            ``"motion"`` key.

    Returns:
        dict with the keys ``m_ref``, ``joints_ref``, ``m_rst``,
        ``joints_rst``, ``loss_commit`` and ``perplexity``.
    """
    # NOTE(review): the earlier comment here said "batch detach", but no
    # detach is performed -- the batch entry is used as-is. Confirm intent.
    motion = batch["motion"]
    motion_joints = self.feats2joints(motion)

    # self.vae yields three values: the reconstructed features plus two
    # scalars-by-name (commit loss, perplexity) that are forwarded untouched.
    recon, commit_loss, codebook_perplexity = self.vae(motion)
    recon_joints = self.feats2joints(recon)

    return {
        "m_ref": motion,
        "joints_ref": motion_joints,
        "m_rst": recon,
        "joints_rst": recon_joints,
        "loss_commit": commit_loss,
        "perplexity": codebook_perplexity,
    }
| 319 | |
| 320 | @torch.no_grad() |
| 321 | def val_vae_forward(self, batch, split="train"): |
no test coverage detected