MCPcopy Index your code
hub / github.com/GuyTevet/motion-diffusion-model / run_loop

Method run_loop

train/training_loop.py:207–250  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

205 self.data.dataset.t2m_dataset.opt.joints_num, self.model.all_goal_joint_names, cond['target_joint_names'], cond['is_heading']).detach()
206
def run_loop(self):
    """Run the training loop over ``self.num_epochs`` passes of ``self.data``.

    For every batch yielded by ``self.data``:

    * ``self.cond_modifiers(cond['y'], motion)`` is applied (it mutates
      ``cond['y']`` in place);
    * ``motion`` and every tensor value of ``cond['y']`` are moved to
      ``self.device`` (non-tensor values are passed through unchanged);
    * ``self.run_step(motion, cond)`` performs the training step;
    * every ``self.log_interval`` steps the logger's key/values are dumped,
      the loss is printed, and the remaining scalars are reported to
      ``self.train_platform``;
    * every ``self.save_interval`` steps a checkpoint is saved and
      ``evaluate`` / ``generate_during_training`` run with the model (and
      the EMA model when ``self.args.use_ema``) temporarily in eval mode.

    When ``self.lr_anneal_steps`` is truthy, training stops as soon as
    ``self.total_step()`` reaches it. After the loop, a final checkpoint is
    saved and evaluated (also in eval mode) unless the last executed step
    already triggered a save.

    If the ``DIFFUSION_TRAINING_TEST`` environment variable is set, the
    method returns right after the first periodic save at a positive step.
    """

    def _anneal_limit_reached():
        # A falsy lr_anneal_steps (e.g. 0) disables the step limit entirely.
        return bool(self.lr_anneal_steps) and self.total_step() >= self.lr_anneal_steps

    def _set_eval_mode(evaluating):
        # Switch the model, and its EMA copy when enabled, between eval and train mode.
        modules = [self.model]
        if self.args.use_ema:
            modules.append(self.model_avg)
        for module in modules:
            if evaluating:
                module.eval()
            else:
                module.train()

    print('train steps:', self.num_steps)
    for epoch in range(self.num_epochs):
        print(f'Starting epoch {epoch}')
        for motion, cond in tqdm(self.data):
            if _anneal_limit_reached():
                break

            self.cond_modifiers(cond['y'], motion)  # Modify in-place for efficiency
            motion = motion.to(self.device)
            cond['y'] = {key: val.to(self.device) if torch.is_tensor(val) else val
                         for key, val in cond['y'].items()}

            self.run_step(motion, cond)
            if self.total_step() % self.log_interval == 0:
                for k, v in logger.get_current().dumpkvs().items():
                    if k == 'loss':
                        print('step[{}]: loss[{:0.5f}]'.format(self.total_step(), v))
                    # 'step'/'samples' are counters, not losses; keys containing
                    # '_q' are presumably per-quantile loss breakdowns (confirm
                    # against run_step's logging) and are not reported either.
                    if k in ('step', 'samples') or '_q' in k:
                        continue
                    self.train_platform.report_scalar(name=k, value=v, iteration=self.total_step(), group_name='Loss')

            if self.total_step() % self.save_interval == 0:
                self.save()
                _set_eval_mode(True)
                self.evaluate()
                self.generate_during_training()
                _set_eval_mode(False)

                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.total_step() > 0:
                    return
            self.step += 1
        # Stop iterating over further epochs once the anneal limit is hit.
        if _anneal_limit_reached():
            break

    # Save the last checkpoint if it wasn't already saved. self.step was
    # incremented after the last executed step, hence the "- 1".
    if (self.total_step() - 1) % self.save_interval != 0:
        self.save()
        # Evaluate in eval mode, consistent with the periodic evaluation above.
        _set_eval_mode(True)
        self.evaluate()
        _set_eval_mode(False)
251
252 def evaluate(self):
253 if not self.args.eval_during_training:

Callers 1

main (Function) · 0.80

Calls 10

total_step (Method) · 0.95
cond_modifiers (Method) · 0.95
run_step (Method) · 0.95
save (Method) · 0.95
evaluate (Method) · 0.95
dumpkvs (Method) · 0.80
to (Method) · 0.45
report_scalar (Method) · 0.45
train (Method) · 0.45

Tested by

no test coverage detected