(self)
| 192 | torch.backends.cudnn.benchmark = False |
| 193 | |
def train(self):
    """Run the training loop, or only inference when not in a train run.

    The model is written to ``self.writer`` first. If ``"train"`` is not
    in ``self.run_type`` the method calls ``self.inference()`` and
    returns. Otherwise it iterates over ``self.train_loader`` epoch by
    epoch, running scheduler step, forward pass, meter update, loss
    extraction, backward pass and logistics for every batch, and finally
    calls ``self.finalize()``.

    Training stops when any of the following holds:

    * ``self.current_iteration`` reaches ``self.max_iterations``;
    * ``self.current_epoch`` reaches ``self.max_epochs``;
    * ``self._logistics(report)`` returns a truthy value.

    Side effects: replaces whichever of ``self.max_epochs`` /
    ``self.max_iterations`` is not the active budget with ``math.inf``,
    creates ``self.train_timer`` and ``self.snapshot_timer``, advances
    ``self.current_epoch`` / ``self.current_iteration`` and publishes
    both to ``registry``.

    Returns:
        None
    """
    self.writer.write("===== Model =====")
    self.writer.write(self.model)

    if "train" not in self.run_type:
        self.inference()
        return

    should_break = False

    # Only one budget is enforced: iterations when max_epochs is unset,
    # epochs otherwise. The other limit is neutralised with infinity.
    if self.max_epochs is None:
        self.max_epochs = math.inf
    else:
        self.max_iterations = math.inf

    self.model.train()
    self.train_timer = Timer()
    self.snapshot_timer = Timer()

    self.profile("Setup Time")

    # NOTE(review): anomaly detection is a process-global autograd switch
    # that PyTorch documents as debug-only because it slows execution.
    # It is kept here to preserve current behaviour -- confirm whether it
    # should stay enabled for regular training runs.
    torch.autograd.set_detect_anomaly(True)

    self.writer.write("Starting training...")
    # Both budgets are checked *before* a counter is advanced, so the
    # counters (and the values published to the registry) never overshoot
    # the configured limits and no sampler is seeded for an epoch that
    # will not run.
    while (
        not should_break
        and self.current_iteration < self.max_iterations
        and self.current_epoch < self.max_epochs
    ):
        self.current_epoch += 1
        registry.register("current_epoch", self.current_epoch)

        # Seed the sampler in case it is distributed
        self.task_loader.seed_sampler("train", self.current_epoch)

        for batch in self.train_loader:
            self.profile("Batch load time")
            self.current_iteration += 1
            self.writer.write(self.current_iteration, "debug")

            registry.register("current_iteration", self.current_iteration)

            self._run_scheduler()
            report = self._forward_pass(batch)
            self._update_meter(report, self.meter)
            loss = self._extract_loss(report)
            self._backward(loss)
            should_break = self._logistics(report)

            # Stop right after the step that exhausts the iteration
            # budget instead of fetching one more batch only to drop it.
            if should_break or self.current_iteration >= self.max_iterations:
                break

    self.finalize()
| 249 | |
| 250 | def _run_scheduler(self): |
| 251 | if self.lr_scheduler is not None: |
no test coverage detected