Repository: github.com/OpenMotionLab/MotionGPT

Method train_vae_forward

mGPT/models/mgpt.py:302–318
(self, batch)

```python
def train_vae_forward(self, batch):
    # Reference motion features from the batch (used as-is; nothing is detached here)
    feats_ref = batch["motion"]
    joints_ref = self.feats2joints(feats_ref)
    # Motion encode & decode
    feats_rst, loss_commit, perplexity = self.vae(feats_ref)
    joints_rst = self.feats2joints(feats_rst)
    # Return set
    rs_set = {
        "m_ref": feats_ref,
        "joints_ref": joints_ref,
        "m_rst": feats_rst,
        "joints_rst": joints_rst,
        "loss_commit": loss_commit,
        "perplexity": perplexity,
    }
    return rs_set
```
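The call `self.vae(feats_ref)` returns a reconstruction together with a commitment loss and a perplexity, which are the usual outputs of a vector-quantized autoencoder. The sketch below is a toy, pure-Python stand-in built on stated assumptions. The encoder and decoder are identity maps, `quantize` and its codebook are hypothetical, and the commitment-loss and perplexity definitions are the common VQ-VAE ones rather than code taken from MotionGPT.

```python
import math


def quantize(feats, codebook):
    """Hypothetical toy stand-in for self.vae: identity encoder/decoder around a
    nearest-neighbour codebook lookup. Returns the same triple as the call in
    train_vae_forward: (reconstruction, commitment loss, perplexity)."""
    # Assign each frame's feature vector to its closest code (squared L2 distance)
    idx = []
    for x in feats:
        dists = [sum((a - b) ** 2 for a, b in zip(x, c)) for c in codebook]
        idx.append(dists.index(min(dists)))
    feats_rst = [list(codebook[i]) for i in idx]

    # Commitment loss: mean squared distance between inputs and their codes.
    # (A real VQ-VAE applies a stop-gradient to the code side; plain Python has no gradients.)
    sq_errs = [(a - b) ** 2 for x, q in zip(feats, feats_rst) for a, b in zip(x, q)]
    loss_commit = sum(sq_errs) / len(sq_errs)

    # Perplexity: exp of the entropy of the empirical code-usage distribution;
    # it ranges from 1 (a single code used) to len(codebook) (uniform usage).
    probs = [idx.count(k) / len(idx) for k in range(len(codebook))]
    perplexity = math.exp(-sum(p * math.log(p) for p in probs if p > 0))
    return feats_rst, loss_commit, perplexity


feats_rst, loss_commit, perplexity = quantize(
    [[0.1, 0.0], [0.9, 1.0], [1.0, 1.2]], codebook=[[0.0, 0.0], [1.0, 1.0]]
)
# feats_rst == [[0.0, 0.0], [1.0, 1.0], [1.0, 1.0]]; the two codes are used with
# frequencies 1/3 and 2/3, so perplexity is about 1.89 (between 1 and 2).
```

Perplexity well below the codebook size signals that only a few codes are in use, which makes it a convenient value to log alongside the loss.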

Callers (1)

allsplit_step · method · 0.95
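`allsplit_step` is the only detected caller, so the return set presumably feeds the VAE-stage loss there. This page does not show how MotionGPT weights those terms. The function below is a hypothetical sketch that combines the reconstruction errors on features and joints with the commitment loss. Its default weights are made up, and it leaves `perplexity` out as a monitoring-only value.

```python
def mse(pred, target):
    """Mean squared error over two nested lists of equal shape (frames x values)."""
    diffs = [(p - t) ** 2 for row_p, row_t in zip(pred, target) for p, t in zip(row_p, row_t)]
    return sum(diffs) / len(diffs)


def vae_stage_loss(rs_set, w_feat=1.0, w_joint=1.0, w_commit=0.02):
    """Hypothetical weighting of train_vae_forward's outputs into one scalar loss.
    The weights are illustrative assumptions, not MotionGPT's configuration."""
    loss_feat = mse(rs_set["m_rst"], rs_set["m_ref"])            # feature reconstruction
    loss_joint = mse(rs_set["joints_rst"], rs_set["joints_ref"])  # joint-space reconstruction
    # perplexity is deliberately not part of the loss in this sketch
    return w_feat * loss_feat + w_joint * loss_joint + w_commit * rs_set["loss_commit"]


example = {
    "m_ref": [[0.0, 1.0]], "m_rst": [[0.0, 0.0]],
    "joints_ref": [[2.0]], "joints_rst": [[0.0]],
    "loss_commit": 0.5, "perplexity": 1.0,
}
total = vae_stage_loss(example)  # 0.5 + 4.0 + 0.01, i.e. about 4.51
```

Comparing in joint space as well as feature space penalizes reconstructions that look close in the feature encoding but place joints noticeably off. That is one plausible reason to compute both `joints_ref` and `joints_rst`.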

Calls (1)

feats2joints · method · 0.45

Tested by

no test coverage detected
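No test covers this method. A unit test can isolate it by passing a mock in place of `self`, stubbing `vae` and `feats2joints`, and then checking how the pieces are wired together. So that the sketch runs without the repository, it copies the method body shown above into a module-level function. In the repository, the test would call the real method instead, for example by importing its class from `mGPT/models/mgpt.py`. That class name is not shown on this page and is left as an assumption.

```python
import unittest
from unittest import mock


def train_vae_forward(self, batch):
    # Copy of the method body shown on this page, so the sketch runs without the repo.
    feats_ref = batch["motion"]
    joints_ref = self.feats2joints(feats_ref)
    feats_rst, loss_commit, perplexity = self.vae(feats_ref)
    joints_rst = self.feats2joints(feats_rst)
    return {
        "m_ref": feats_ref,
        "joints_ref": joints_ref,
        "m_rst": feats_rst,
        "joints_rst": joints_rst,
        "loss_commit": loss_commit,
        "perplexity": perplexity,
    }


class TrainVaeForwardTest(unittest.TestCase):
    def test_wires_vae_and_feats2joints(self):
        fake = mock.Mock()
        # Stub the VAE: returns (reconstruction, commitment loss, perplexity)
        fake.vae.return_value = ("feats_rst", 0.25, 7.0)
        # Stub feats2joints: tag its input so the test can see what was converted
        fake.feats2joints.side_effect = lambda f: ("joints", f)

        rs = train_vae_forward(fake, {"motion": "feats_ref"})

        fake.vae.assert_called_once_with("feats_ref")
        self.assertEqual(fake.feats2joints.call_count, 2)
        self.assertEqual(rs["m_ref"], "feats_ref")
        self.assertEqual(rs["m_rst"], "feats_rst")
        self.assertEqual(rs["joints_ref"], ("joints", "feats_ref"))
        self.assertEqual(rs["joints_rst"], ("joints", "feats_rst"))
        self.assertEqual((rs["loss_commit"], rs["perplexity"]), (0.25, 7.0))
```

Run it with a test runner such as `python -m unittest`. Plain strings stand in for tensors because the method only passes values through, so the test checks wiring rather than numerics.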