| 151 | self.opt.load_state_dict(state_dict) |
| 152 | |
def run_loop(self):
    """Run training steps until the step budget is exhausted, then checkpoint.

    Each iteration pulls one ``(batch, cond)`` pair from ``self.data`` and
    hands it to ``self.run_step``. Whenever ``self.step`` is a multiple of
    ``self.log_interval`` the method calls ``logger.dumpkvs()``; whenever it
    is a multiple of ``self.save_interval`` it calls ``self.save()``.

    The loop ends once ``self.step + self.resume_step`` reaches
    ``self.lr_anneal_steps``. A falsy ``lr_anneal_steps`` disables that limit,
    so the loop then only ends through the test-mode return below or through
    an exception (for example ``StopIteration`` from ``next(self.data)``).
    """
    while True:
        # A falsy budget means "no limit"; otherwise stop once the combined
        # (current + resumed) step count is no longer below the budget.
        step_budget = self.lr_anneal_steps
        if step_budget and not (self.step + self.resume_step < step_budget):
            break

        samples, conditioning = next(self.data)
        self.run_step(samples, conditioning)

        if not self.step % self.log_interval:
            logger.dumpkvs()

        if not self.step % self.save_interval:
            self.save()
            # Run for a finite amount of time in integration tests.
            # NOTE(review): this early exit is assumed to be nested under the
            # checkpoint branch (so a test run always ends right after a
            # save) -- confirm against the original indentation.
            test_flag = os.environ.get("DIFFUSION_TRAINING_TEST", "")
            if test_flag and self.step > 0:
                return  # leaves self.step un-incremented, right after a save

        self.step += 1

    # Save the last checkpoint if it wasn't already saved: the most recently
    # executed step is ``self.step - 1`` because of the trailing increment.
    last_step = self.step - 1
    if last_step % self.save_interval:
        self.save()
| 171 | |
| 172 | def run_step(self, batch, cond): |
| 173 | self.forward_backward(batch, cond) |