Full training logic
(self)
| 58 | raise NotImplementedError |
| 59 | |
def train(self):
    """
    Full training logic.

    Calls ``self._train_epoch(epoch)`` for every epoch from
    ``self.start_epoch`` to ``self.epochs`` (inclusive), merges the returned
    mapping into a log dict keyed additionally by ``'epoch'``, logs every
    key/value pair through ``self.logger.info``, and calls
    ``self._save_checkpoint`` every ``self.save_period`` epochs.

    While ``self.mnt_mode`` is not ``'off'``, ``log[self.mnt_metric]`` is
    compared with ``self.mnt_best``:

    * ``'min'``: a value ``<=`` the best so far is an improvement.
    * ``'max'``: a value ``>=`` the best so far is an improvement.
    * Any other mode value never registers an improvement.

    An improvement updates ``self.mnt_best``, resets the no-improvement
    counter and marks this epoch's checkpoint as best. Otherwise the counter
    grows, and once it exceeds ``self.early_stop`` the loop stops before
    saving that epoch's checkpoint.

    If ``self.mnt_metric`` is absent from the log, a warning is logged and
    ``self.mnt_mode`` is set to ``'off'`` for the rest of training. That epoch
    does not count towards early stopping.
    """
    not_improved_count = 0
    for epoch in range(self.start_epoch, self.epochs + 1):
        result = self._train_epoch(epoch)

        # save logged information into the log dict
        log = {'epoch': epoch}
        log.update(result)

        # print logged information to the screen
        for key, value in log.items():
            self.logger.info(' {:15s}: {}'.format(str(key), value))

        # evaluate model performance according to the configured metric;
        # an improvement marks this epoch's checkpoint as the best one
        best = False
        if self.mnt_mode != 'off':
            try:
                # check whether model performance improved, according to mnt_metric
                # (ties count as an improvement because of <= / >=)
                improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                           (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
            except KeyError:
                self.logger.warning("Warning: Metric '{}' is not found. "
                                    "Model performance monitoring is disabled.".format(self.mnt_metric))
                # Monitoring is now disabled, so this epoch must not feed the
                # early-stopping counter (it used to, which could stop training
                # right here and skip this epoch's checkpoint).
                self.mnt_mode = 'off'
            else:
                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn't improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

        if epoch % self.save_period == 0:
            self._save_checkpoint(epoch, save_best=best)
| 103 | |
| 104 | def _prepare_device(self, n_gpu_use): |
| 105 | """ |
no test coverage detected