(self)
| 205 | self.data.dataset.t2m_dataset.opt.joints_num, self.model.all_goal_joint_names, cond['target_joint_names'], cond['is_heading']).detach() |
| 206 | |
def run_loop(self):
    """Drive training for ``self.num_epochs`` passes over ``self.data``.

    For every batch: apply ``self.cond_modifiers`` to ``cond['y']``, move
    ``motion`` and every tensor entry of ``cond['y']`` to ``self.device``
    via ``.to(...)``, then call ``self.run_step(motion, cond)``.

    Every ``self.log_interval`` steps, the key/values returned by
    ``logger.get_current().dumpkvs()`` are sent to
    ``self.train_platform.report_scalar`` under group ``'Loss'``. Keys
    ``'step'``, ``'samples'`` and any key containing ``'_q'`` are not
    reported; the ``'loss'`` value is additionally printed.

    Every ``self.save_interval`` steps a checkpoint is saved. Evaluation and
    sample generation then run with ``self.model`` (and ``self.model_avg``
    when ``self.args.use_ema``) in eval mode, after which they are switched
    back to train mode.

    Training stops early once ``self.total_step()`` reaches a truthy
    ``self.lr_anneal_steps``. If the ``DIFFUSION_TRAINING_TEST`` environment
    variable is non-empty, the method returns right after the first save made
    at a step greater than zero, which also skips the final save. Otherwise,
    a final save plus evaluation runs unless the last processed step already
    saved.
    """

    def anneal_limit_reached():
        # A falsy lr_anneal_steps disables the limit entirely; total_step()
        # is only consulted when a limit is configured.
        return bool(self.lr_anneal_steps) and self.total_step() >= self.lr_anneal_steps

    def switch_mode(mode_name):
        # Call model.<mode_name>() and mirror it on the EMA copy when enabled.
        getattr(self.model, mode_name)()
        if self.args.use_ema:
            getattr(self.model_avg, mode_name)()

    print('train steps:', self.num_steps)
    for epoch in range(self.num_epochs):
        print(f'Starting epoch {epoch}')
        for motion, cond in tqdm(self.data):
            if anneal_limit_reached():
                break

            # Presumably mutates cond['y'] in place (original note: "Modify in-place for efficiency").
            self.cond_modifiers(cond['y'], motion)
            motion = motion.to(self.device)
            raw_y = cond['y']
            cond['y'] = {
                field: (item.to(self.device) if torch.is_tensor(item) else item)
                for field, item in raw_y.items()
            }

            self.run_step(motion, cond)

            if self.total_step() % self.log_interval == 0:
                for key, value in logger.get_current().dumpkvs().items():
                    if key == 'loss':
                        print('step[{}]: loss[{:0.5f}]'.format(self.total_step(), value))
                    skip_report = key in ['step', 'samples'] or '_q' in key
                    if not skip_report:
                        self.train_platform.report_scalar(
                            name=key,
                            value=value,
                            iteration=self.total_step(),
                            group_name='Loss',
                        )

            if self.total_step() % self.save_interval == 0:
                self.save()
                switch_mode('eval')
                self.evaluate()
                self.generate_during_training()
                switch_mode('train')

                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.total_step() > 0:
                    return

            self.step += 1

        if anneal_limit_reached():
            break

    # Save the last checkpoint unless the final processed step already saved it.
    if (self.total_step() - 1) % self.save_interval:
        self.save()
        self.evaluate()
| 252 | def evaluate(self): |
| 253 | if not self.args.eval_during_training: |
no test coverage detected