Repository: github.com/black0017/MedicalZooPytorch

Method train_epoch

lib/train/trainer.py:52–73
Signature: train_epoch(self, epoch)


```python
            self.writer.reset('val')

    def train_epoch(self, epoch):
        self.model.train()

        for batch_idx, input_tuple in enumerate(self.train_data_loader):

            self.optimizer.zero_grad()

            input_tensor, target = prepare_input(input_tuple=input_tuple, args=self.args)
            input_tensor.requires_grad = True
            output = self.model(input_tensor)
            loss_dice, per_ch_score = self.criterion(output, target)
            loss_dice.backward()
            self.optimizer.step()

            self.writer.update_scores(batch_idx, loss_dice.item(), per_ch_score, 'train',
                                      epoch * self.len_epoch + batch_idx)

            if (batch_idx + 1) % self.terminal_show_freq == 0:
                partial_epoch = epoch + batch_idx / self.len_epoch - 1
                self.writer.display_terminal(partial_epoch, epoch, 'train')

        self.writer.display_terminal(self.len_epoch, epoch, mode='train', summary=True)

    def validate_epoch(self, epoch):
        self.model.eval()
```
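The call `self.criterion(output, target)` is expected to return two values: a scalar loss, on which `backward()` and `.item()` are called, and one score per output channel, which is passed on to `update_scores`. The sketch below illustrates that contract with a soft Dice loss on plain Python lists instead of tensors. It is an illustration under stated assumptions, not the repository's loss class: the name `soft_dice_loss`, the flattened per-channel list layout, and the `eps` guard are all hypothetical, and the library's own Dice variants may weight or normalize channels differently.

```python
from typing import List, Sequence, Tuple


def soft_dice_loss(
    pred: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    eps: float = 1e-6,
) -> Tuple[float, List[float]]:
    """Hypothetical soft Dice loss returning (loss, per-channel scores).

    pred[c] holds predicted probabilities for channel c and target[c] the
    matching one-hot labels, both flattened to one value per voxel.
    """
    per_channel: List[float] = []
    for p_c, t_c in zip(pred, target):
        intersection = sum(p * t for p, t in zip(p_c, t_c))
        denominator = sum(p_c) + sum(t_c)
        # eps avoids division by zero when a channel is empty in both inputs.
        per_channel.append(2.0 * intersection / max(denominator, eps))
    # One minus the mean Dice: 0 for a perfect match, 1 for no overlap.
    loss = 1.0 - sum(per_channel) / len(per_channel)
    return loss, per_channel


if __name__ == "__main__":
    loss, scores = soft_dice_loss([[1, 0, 0, 0], [0, 1, 1, 1]],
                                  [[1, 1, 0, 0], [0, 0, 1, 1]])
    print(round(loss, 4), [round(s, 4) for s in scores])  # 0.2667 [0.6667, 0.8]
```

With a perfect prediction both channel scores are 1.0 and the loss is 0.0. In the real method the loss has to stay a tensor, because `loss_dice.backward()` propagates gradients from it to the model parameters; only the logged value is converted with `.item()`.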

Callers (1)

training · Method · 0.95
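Two counters in `train_epoch` depend on how the caller numbers epochs. Assuming `training` loops `for epoch in range(start_epoch, n_epochs)` with `start_epoch = 1` (an assumption about the caller, whose body is not shown on this page), the arithmetic works out as in this sketch. The helper names `global_step`, `partial_epoch`, and `progress_batches` are hypothetical; each one copies a single expression from the listing.

```python
from typing import List


def global_step(epoch: int, batch_idx: int, len_epoch: int) -> int:
    # Same expression as the last argument passed to update_scores.
    return epoch * len_epoch + batch_idx


def partial_epoch(epoch: int, batch_idx: int, len_epoch: int) -> float:
    # Same expression as the progress value passed to display_terminal.
    return epoch + batch_idx / len_epoch - 1


def progress_batches(len_epoch: int, terminal_show_freq: int) -> List[int]:
    # Batch indices on which the (batch_idx + 1) % freq == 0 test fires.
    return [b for b in range(len_epoch) if (b + 1) % terminal_show_freq == 0]


if __name__ == "__main__":
    # Hypothetical run: epochs numbered from 1, 8 batches per epoch, print every 4.
    start_epoch, n_epochs, len_epoch, freq = 1, 3, 8, 4
    for epoch in range(start_epoch, n_epochs):
        for b in progress_batches(len_epoch, freq):
            print(epoch, b, global_step(epoch, b, len_epoch),
                  partial_epoch(epoch, b, len_epoch))
    # Prints: 1 3 11 0.375 / 1 7 15 0.875 / 2 3 19 1.375 / 2 7 23 1.875
```

Under that assumption `partial_epoch` approximates the number of epochs completed so far; because it uses `batch_idx` rather than `batch_idx + 1`, it lags one batch behind the count of finished batches. The step passed to `update_scores` is contiguous across epochs but starts at `len_epoch` rather than 0.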

Calls (4)

prepare_input · Function · 0.90
train · Method · 0.80
update_scores · Method · 0.80
display_terminal · Method · 0.45
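Inside `train_epoch` the writer is used in only two ways: `update_scores` once per batch with the loss and per-channel scores, and `display_terminal` for periodic progress lines and an end-of-epoch summary. The writer's implementation is not shown on this page, so the following is a hypothetical stand-in (`RunningScores` is not a repository class) that keeps running sums and reports means, which is one plausible way such a summary can be produced. It does not write to TensorBoard and tracks a single mode only.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class RunningScores:
    """Hypothetical tracker, not the repository's writer: keeps running sums."""

    loss_sum: float = 0.0
    channel_sums: List[float] = field(default_factory=list)
    count: int = 0

    def update(self, loss: float, per_ch_score: List[float]) -> None:
        # Called once per batch, in the role update_scores plays in train_epoch.
        if not self.channel_sums:
            self.channel_sums = [0.0] * len(per_ch_score)
        self.loss_sum += loss
        self.channel_sums = [s + x for s, x in zip(self.channel_sums, per_ch_score)]
        self.count += 1

    def summary(self) -> str:
        # Epoch-level report, in the role of display_terminal(..., summary=True).
        if self.count == 0:
            return "no batches"
        mean_loss = self.loss_sum / self.count
        channel_means = [s / self.count for s in self.channel_sums]
        mean_score = sum(channel_means) / len(channel_means) if channel_means else float("nan")
        per_channel = " ".join(f"{m:.4f}" for m in channel_means)
        return f"loss {mean_loss:.4f} score {mean_score:.4f} per-channel {per_channel}"


if __name__ == "__main__":
    tracker = RunningScores()
    tracker.update(0.5, [0.4, 0.6])
    tracker.update(0.3, [0.6, 0.8])
    print(tracker.summary())  # loss 0.4000 score 0.6000 per-channel 0.5000 0.7000
```

Keeping sums rather than a list of every batch keeps memory constant per epoch; a reset between epochs (as the `reset('val')` call at the top of the listing suggests the real writer supports) would simply create a fresh tracker.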

Tested by

no test coverage detected