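Network training loop.
`train_model(self, max_iters)`: runs one SGD step per iteration until `self.solver.iter` reaches `max_iters`. It prints the average time per iteration every `10 * display` iterations and snapshots the model every `cfg.TRAIN.SNAPSHOT_ITERS` iterations. It takes one final snapshot if the last iteration was not already snapshotted.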
| 82 | net.params['bbox_pred'][1].data[...] = orig_1 |
| 83 | |
| 84 | def train_model(self, max_iters): |
| 85 | """Network training loop.""" |
| 86 | last_snapshot_iter = -1 |
| 87 | timer = Timer() |
| 88 | while self.solver.iter < max_iters: |
| 89 | # Make one SGD update |
| 90 | timer.tic() |
| 91 | self.solver.step(1) |
| 92 | timer.toc() |
| 93 | if self.solver.iter % (10 * self.solver_param.display) == 0: |
| 94 | print 'speed: {:.3f}s / iter'.format(timer.average_time) |
| 95 | |
| 96 | if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0: |
| 97 | last_snapshot_iter = self.solver.iter |
| 98 | self.snapshot() |
| 99 | |
| 100 | if last_snapshot_iter != self.solver.iter: |
| 101 | self.snapshot() |
| 102 | |
| 103 | def get_training_roidb(imdb): |
| 104 | """Returns a roidb (Region of Interest database) for use in training.""" |