MCPcopy
hub / github.com/black0017/MedicalZooPytorch / train

Method train

lib/train/BaseTrainer.py:60–102  ·  view source on GitHub ↗

Full training logic

(self)

Source from the content-addressed store, hash-verified

58 raise NotImplementedError
59
def train(self):
    """
    Full training logic.

    For every epoch from ``self.start_epoch`` to ``self.epochs`` (inclusive):
    run ``self._train_epoch(epoch)``, log the returned metrics, update the
    monitored metric ``self.mnt_metric`` according to ``self.mnt_mode``
    ('min', 'max' or 'off'), and save a checkpoint when due.

    Early stopping: training stops once more than ``self.early_stop``
    consecutive epochs fail to improve the monitored metric. If the metric
    is missing from an epoch's results, monitoring (and therefore early
    stopping) is switched off for the rest of training.

    Checkpoints are saved every ``self.save_period`` epochs and additionally
    on every epoch that improves the monitored metric, so the best model is
    not lost when ``save_period > 1``.
    """
    not_improved_count = 0
    for epoch in range(self.start_epoch, self.epochs + 1):
        result = self._train_epoch(epoch)

        # save logged information into the log dict
        log = {'epoch': epoch}
        log.update(result)

        # print logged information to the screen
        for key, value in log.items():
            self.logger.info(' {:15s}: {}'.format(str(key), value))

        # evaluate model performance according to the configured metric;
        # improving epochs are saved with save_best=True (model_best)
        best = False
        if self.mnt_mode != 'off':
            try:
                current = log[self.mnt_metric]
            except KeyError:
                self.logger.warning("Warning: Metric '{}' is not found. "
                                    "Model performance monitoring is disabled.".format(self.mnt_metric))
                # Monitoring is now off, so this epoch must not count towards
                # early stopping (otherwise early_stop == 0 would abort
                # training immediately).
                self.mnt_mode = 'off'
            else:
                # ties count as improvements (<= / >=)
                improved = ((self.mnt_mode == 'min' and current <= self.mnt_best) or
                            (self.mnt_mode == 'max' and current >= self.mnt_best))
                if improved:
                    self.mnt_best = current
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn't improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

        # Save on the regular period and also whenever the metric improved;
        # otherwise, with save_period > 1, a best epoch that is not a
        # multiple of save_period would never be written.
        if epoch % self.save_period == 0 or best:
            self._save_checkpoint(epoch, save_best=best)
103
104 def _prepare_device(self, n_gpu_use):
105 """

Callers 3

train (function) · 0.80
train_epoch (method) · 0.80
train_dice (function) · 0.80

Calls 2

_train_epoch (method) · 0.95
update (method) · 0.80

Tested by

no test coverage detected