MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / validation_end

Method validation_end

utils/commons/base_task.py:188–220  ·  view source on GitHub ↗

:param outputs: :return: loss_output: dict

(self, outputs)

Source from the content-addressed store, hash-verified

186 raise NotImplementedError
187
def validation_end(self, outputs):
    """Aggregate per-batch validation outputs into averaged loss values.

    Each element of ``outputs`` may be:
      * ``None`` or an empty container -- skipped;
      * a dict with a required ``'losses'`` mapping, an optional
        ``'nsamples'`` weight (default 1) and an optional ``'total_loss'``
        (default: the sum of the ``'losses'`` values);
      * a 2-element sequence ``(total_loss, losses)`` with weight 1.

    Every loss name gets its own ``AvgrageMeter`` that is updated with the
    value and its weight ``n``; ``'total_loss'`` always has a meter.
    The input dicts are read, not modified.

    Also prints the rounded averages tagged with ``self.global_step``.

    :param outputs: iterable of per-batch validation outputs (see above).
    :return: loss_output: dict with
        ``'tb_log'`` -- ``{'val/<loss name>': avg}`` for every meter, and
        ``'val_loss'`` -- the averaged ``'total_loss'``.
    """
    all_losses_meter = {'total_loss': AvgrageMeter()}
    for output in outputs:
        if output is None or len(output) == 0:
            continue
        if isinstance(output, dict):
            assert 'losses' in output, 'Key "losses" should exist in validation output.'
            # Read (don't pop) the weight: popping mutated the caller's dict, so a
            # second aggregation over the same outputs silently fell back to n=1.
            n = output.get('nsamples', 1)
            losses = tensors_to_scalars(output['losses'])
            # Only sum the individual losses when no explicit total was given.
            if 'total_loss' in output:
                total_loss = output['total_loss']
            else:
                total_loss = sum(losses.values())
        else:
            assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)'
            n = 1
            total_loss, losses = output
            losses = tensors_to_scalars(losses)
        # .item() turns a (presumably single-element) tensor into a Python number,
        # mirroring the tensors_to_scalars conversion applied to the per-loss values.
        if isinstance(total_loss, torch.Tensor):
            total_loss = total_loss.item()
        for k, v in losses.items():
            if k not in all_losses_meter:
                all_losses_meter[k] = AvgrageMeter()
            all_losses_meter[k].update(v, n)
        all_losses_meter['total_loss'].update(total_loss, n)
    loss_output = {k: round(v.avg, 10) for k, v in all_losses_meter.items()}
    print(f"| Validation results@{self.global_step}: {loss_output}")
    return {
        'tb_log': {f'val/{k}': v for k, v in loss_output.items()},
        'val_loss': loss_output['total_loss']
    }
221
222 ######################
223 # testing

Callers 2

test_end (method) · 0.95
evaluate (method) · 0.45

Calls 3

AvgrageMeter (class) · 0.90
tensors_to_scalars (function) · 0.90
update (method) · 0.45

Tested by 1

test_end (method) · 0.76