MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / forward

Method forward

mGPT/models/mgpt.py:65–121  ·  view source on GitHub ↗
(self, batch, task="t2m")

Source from the content-addressed store, hash-verified

63 self.codeFrequency = torch.zeros((self.hparams.codebook_size, ))
64
def forward(self, batch, task="t2m"):
    """Generate with the language model, decode motions and pad them into a batch.

    The texts in ``batch["text"]`` are passed to ``self.lm.generate_direct``
    with sampling enabled. Each generated sequence is decoded by
    ``self.vae.decode``, optionally combined with reference data depending on
    ``task``, zero-padded to the longest sample and converted to joints with
    ``self.feats2joints``.

    Args:
        batch: Mapping with at least ``"text"`` and ``"length"``. Task
            ``"pred"`` also reads ``batch["motion"]`` (concatenated in front
            of the generated output before decoding; presumably conditioning
            motion tokens — TODO confirm). Task ``"inbetween"`` also reads
            ``batch["motion_heading"]`` and ``batch["motion_tailing"]``.
        task: One of ``"t2m"``, ``"m2t"``, ``"pred"`` or ``"inbetween"``.

    Returns:
        Dict with:
          - ``"texts"``: the texts returned by ``generate_direct``;
          - ``"feats"``: zero-padded float tensor of shape
            ``(len(batch["text"]), max_len, feat_dim)`` on ``self.device``;
          - ``"joints"``: ``self.feats2joints(feats)``;
          - ``"length"``: number of valid (non-padding) frames in each row of
            ``"feats"``, for every task.

    Raises:
        NotImplementedError: If ``task`` is not supported. This is checked
            before the language model is run.
        ValueError: If ``batch["text"]`` is empty.
    """
    if task not in ("t2m", "m2t", "pred", "inbetween"):
        raise NotImplementedError

    texts = batch["text"]
    lengths_ref = batch["length"]
    if len(texts) == 0:
        # Without at least one decoded motion the feature size is unknown.
        raise ValueError("forward() requires a non-empty batch")

    # Forward
    outputs, output_texts = self.lm.generate_direct(texts, do_sample=True)

    # Motion decode. Each sample's *final* motion is built before it is
    # measured, so "length" and the padding size agree with what is written
    # into feats. Assumes decode() returns (1, frames, feat_dim) — TODO confirm.
    feats_rst_lst = []
    for i, generated in enumerate(outputs[:len(texts)]):
        if task == "pred":
            # Decode the provided prefix together with the generated part.
            motion = self.vae.decode(torch.cat((batch["motion"][i], generated)))
        else:
            motion = self.vae.decode(generated)

        if task == "inbetween":
            # Keep only the middle segment of the generated motion and splice
            # it between the reference head and tail.
            start = lengths_ref[i] // 4
            end = start * 3
            motion = torch.cat(
                (batch["motion_heading"][i][None],
                 motion[:, start:end, ...],
                 batch["motion_tailing"][i][None]),
                dim=1)

        feats_rst_lst.append(motion)

    lengths = [motion.shape[1] for motion in feats_rst_lst]
    max_len = max(lengths)

    # Zero-pad every sample to the longest one and stack into one tensor.
    feats_rst = torch.zeros(
        (len(feats_rst_lst), max_len, feats_rst_lst[-1].shape[-1]),
        device=self.device)
    for i, (motion, length) in enumerate(zip(feats_rst_lst, lengths)):
        feats_rst[i, :length, ...] = motion

    # Recover joints for evaluation
    joints_rst = self.feats2joints(feats_rst)

    return {
        "texts": output_texts,
        "feats": feats_rst,
        "joints": joints_rst,
        "length": lengths,
    }
122

Callers 1

predict_step (method) · 0.45

Calls 4

generate_direct (method) · 0.80
decode (method) · 0.80
to (method) · 0.45
feats2joints (method) · 0.45

Tested by

no test coverage detected