:param outputs: :return: loss_output: dict
(self, outputs)
| 186 | raise NotImplementedError |
| 187 | |
def validation_end(self, outputs):
    """Aggregate per-batch validation outputs into averaged losses.

    Each element of ``outputs`` is handled as follows:

    * ``None`` or an empty container: skipped.
    * A dict: must contain ``'losses'`` (a mapping of loss name to value).
      It may also contain ``'nsamples'``, the weight passed to the meters
      (default 1). It may also contain ``'total_loss'``, which defaults to
      the sum of the converted ``'losses'`` values.
    * Anything else: must be a 2-element sequence ``(total_loss, losses)``,
      weighted with 1.

    Loss values go through ``tensors_to_scalars``, which presumably turns
    tensor values into Python scalars (defined elsewhere; confirm). They are
    then accumulated per key in an ``AvgrageMeter`` via ``update(value, n)``.
    NOTE(review): this is presumably an n-weighted running mean; confirm
    against the ``AvgrageMeter`` definition. The input dicts are not
    modified.

    :param outputs: iterable of per-batch validation outputs (see above).
    :return: dict with ``'tb_log'`` (averaged losses, each rounded to 10
        decimal places, keyed ``'val/<name>'``) and ``'val_loss'`` (the
        averaged ``total_loss``).
    """
    all_losses_meter = {'total_loss': AvgrageMeter()}
    for output in outputs:
        # `is None` is checked first so len() is never called on None.
        if output is None or len(output) == 0:
            continue
        if isinstance(output, dict):
            assert 'losses' in output, 'Key "losses" should exist in validation output.'
            # Read (don't pop) the sample count. Popping mutated the caller's
            # dicts, so aggregating the same outputs twice lost the weights.
            n = output.get('nsamples', 1)
            losses = tensors_to_scalars(output['losses'])
            # Sum the individual losses only when no explicit total was given.
            if 'total_loss' in output:
                total_loss = output['total_loss']
            else:
                total_loss = sum(losses.values())
        else:
            assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)'
            n = 1
            total_loss, losses = output
            losses = tensors_to_scalars(losses)
        # total_loss may still be a tensor (either branch); the meters take
        # plain numbers.
        if isinstance(total_loss, torch.Tensor):
            total_loss = total_loss.item()
        for k, v in losses.items():
            if k not in all_losses_meter:
                all_losses_meter[k] = AvgrageMeter()
            all_losses_meter[k].update(v, n)
        all_losses_meter['total_loss'].update(total_loss, n)
    loss_output = {k: round(v.avg, 10) for k, v in all_losses_meter.items()}
    print(f"| Validation results@{self.global_step}: {loss_output}")
    return {
        'tb_log': {f'val/{k}': v for k, v in loss_output.items()},
        'val_loss': loss_output['total_loss']
    }
| 221 | |
| 222 | ###################### |
| 223 | # testing |