MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / val_t2m_forward

Method val_t2m_forward

mGPT/models/mgpt.py:138–208  ·  view source on GitHub ↗
(self, batch)

Source from the content-addressed store, hash-verified

136
137 @torch.no_grad()
138 def val_t2m_forward(self, batch):
139 feats_ref = batch["motion"]
140 texts = batch["text"]
141 lengths = batch["length"]
142 tasks = None
143 if self.trainer.datamodule.is_mm:
144 texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS
145 feats_ref = feats_ref.repeat_interleave(
146 self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0)
147 lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS
148 instructions = pjoin(self.datamodule.hparams.data_root,
149 'template_instructions.json')
150 instructions = json.load(open(instructions, 'r'))
151 tasks = [instructions["Text-to-Motion"]["caption"]] * len(texts)
152
153 if self.hparams.condition == 'caption':
154 tasks = [{
155 'input': ['<Caption_Placeholder>'],
156 'output': ['']
157 }] * len(texts)
158
159 if self.hparams.cfg.DATASET.TASK_PATH:
160 instructions = pjoin(self.hparams.cfg.DATASET.TASK_PATH)
161 instructions = json.load(open(instructions, 'r'))
162 tasks = [instructions["Text-to-Motion"]["t2m"]] * len(texts)
163
164 min_len = lengths.copy()
165 # Forward
166 outputs = self.lm.generate_conditional(texts,
167 lengths=lengths,
168 stage='test',
169 tasks=tasks)
170
171 # Motion Decode
172 feats_rst = torch.zeros_like(feats_ref)
173
174 for i in range(len(texts)):
175 outputs[i] = torch.clamp(outputs[i],
176 0,
177 self.hparams.codebook_size - 1,
178 out=None)
179
180 if len(outputs[i]) > 1:
181 motion = self.vae.decode(outputs[i])
182 else:
183 motion = torch.zeros_like(feats_ref[i:i + 1, ...])
184
185 min_len[i] = min(motion.shape[1], lengths[i])
186
187 # Cut Motion
188 feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :lengths[i]]
189
190 # Recover joints for evaluation
191 joints_ref = self.feats2joints(feats_ref)
192 joints_rst = self.feats2joints(feats_rst)
193
194 # Renorm for evaluation
195 feats_ref = self.datamodule.renorm4t2m(feats_ref)

Callers 1

allsplit_step · Method · 0.95

Calls 4

generate_conditional · Method · 0.80
decode · Method · 0.80
feats2joints · Method · 0.45
renorm4t2m · Method · 0.45

Tested by

no test coverage detected