MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / training_step

Method training_step

utils/commons/base_task.py:105–161  ·  view source on GitHub ↗

Parameters: sample, batch_idx, optimizer_idx. Returns: {'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict}

(self, sample, batch_idx, optimizer_idx=-1)

Source from the content-addressed store, hash-verified

103 raise NotImplementedError
104
def training_step(self, sample, batch_idx, optimizer_idx=-1):
    """Run one training step and route its scalar outputs to the loggers.

    The loss computation itself is delegated to ``self._training_step``;
    this method only post-processes what that call returns: it feeds the
    per-epoch meters, then builds the progress-bar and tensorboard dicts.

    :param sample: batch, passed unchanged to ``self._training_step``
    :param batch_idx: batch index, passed unchanged to ``self._training_step``
    :param optimizer_idx: index into ``self.trainer.optimizers``; when it is
        >= 0 the learning rate of every param group of that optimizer is
        added to the logs, otherwise (default -1) no learning rate is logged
    :return: ``{'loss': None}`` when ``self._training_step`` returns ``None``
        or when the total loss is not a ``torch.Tensor``; otherwise
        ``{'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict}``
    """

    def flatten_tag(tag):
        # `<tag>/<sub_tag>` -> `<tag>_<sub_tag>`; tags without `/` pass through.
        # NOTE: this check is an `assert`, so it is skipped under `python -O`.
        if '/' in tag:
            assert len(tag.split("/")) == 2, "we only support one `/` in tag_name, i.e., `<tag>/<sub_tag>`"
            return tag.replace("/", "_")
        return tag

    # perform the main training step in a specific task
    loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
    if loss_ret is None:
        return {'loss': None}
    total_loss, log_outputs = loss_ret
    log_outputs = tensors_to_scalars(log_outputs)

    # add to epoch meter; a meter is created for every key, but NaN values
    # are not fed into it
    for k, v in log_outputs.items():
        k = flatten_tag(k)
        if k not in self.epoch_training_losses_meter:
            self.epoch_training_losses_meter[k] = AvgrageMeter()
        if not np.isnan(v):
            self.epoch_training_losses_meter[k].update(v)

    # learning rates are added after the meter loop on purpose, so they are
    # logged but never averaged in the epoch meters
    if optimizer_idx >= 0:
        optimizer = self.trainer.optimizers[optimizer_idx]
        for group_idx, param_group in enumerate(optimizer.param_groups):
            log_outputs[f'lr/optimizer{optimizer_idx}_params_group{group_idx}'] = param_group['lr']

    # add to progress bar (flattened tags must stay unique)
    progress_bar_log = {}
    for k, v in log_outputs.items():
        k = flatten_tag(k)
        assert k not in progress_bar_log, f"we got duplicate tags in log_outputs, check this `{k}`"
        progress_bar_log[k] = v

    # add to tensorboard log: tags that already have a `/` are kept as-is,
    # all others are grouped under the `tr/` prefix
    tb_log = {(k if '/' in k else f'tr/{k}'): v for k, v in log_outputs.items()}

    if not isinstance(total_loss, torch.Tensor):
        return {'loss': None}
    # create the meter lazily, like the per-key meters above, so this line
    # does not depend on 'total_loss' having been pre-seeded elsewhere
    if 'total_loss' not in self.epoch_training_losses_meter:
        self.epoch_training_losses_meter['total_loss'] = AvgrageMeter()
    self.epoch_training_losses_meter['total_loss'].update(total_loss.item())

    return {
        'loss': total_loss,
        'progress_bar': progress_bar_log,
        'tb_log': tb_log,
    }
162

Callers 3

forwardMethod · 0.80
_run_ddp_forwardMethod · 0.80
run_training_batchMethod · 0.80

Calls 4

_training_stepMethod · 0.95
tensors_to_scalarsFunction · 0.90
AvgrageMeterClass · 0.90
updateMethod · 0.45

Tested by

no test coverage detected