(self, y_pred: torch.Tensor, y_real: torch.Tensor)
| 30 | self.reduce = reduce |
| 31 | |
def update(self, y_pred: torch.Tensor, y_real: torch.Tensor) -> None:
    """Fold one batch into the running measurement.

    ``self.measure_fn(y_pred=..., y_real=...)`` is evaluated on the batch and
    converted to a Python number with ``.item()`` (so ``measure_fn`` must
    return a single-element tensor). That per-batch result is then combined
    with the stored ``self.measure`` according to ``self.reduce``:

    * ``'mean'`` -- running mean, each batch weighted by its size along dim 0.
    * ``'max'``  -- running maximum of the per-batch results.

    In both modes ``self.num_of_elements`` accumulates the number of rows
    (size along dim 0) seen so far. An empty batch is ignored.

    Args:
        y_pred: Predictions; dim 0 is treated as the batch dimension.
        y_real: Targets; must match ``y_pred`` in size along dim 0.

    Raises:
        ValueError: If ``y_pred`` and ``y_real`` differ in size along dim 0.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    elements = y_pred.shape[0]
    if elements != y_real.shape[0]:
        raise ValueError(
            'Can not update measurement, cause your input data do not share a same batchsize. '
            f'Shape of y_pred {y_pred.shape} - against shape of y_real {y_real.shape}')
    if elements == 0:
        # An empty batch carries no information. Skipping it avoids a 0/0 in
        # the mean update (first batch) and avoids folding a NaN result
        # (weighted by 0, which is still NaN) into an existing mean.
        return
    result = self.measure_fn(y_pred=y_pred, y_real=y_real).item()

    if self.reduce == 'mean':
        total = self.num_of_elements + elements
        # Weighted running mean: previous mean weighted by the rows already
        # seen, plus this batch's result weighted by its own row count.
        self.measure = (self.measure * self.num_of_elements + result * elements) / total
        self.num_of_elements = total
    elif self.reduce == 'max':
        self.measure = max(self.measure, result)
        self.num_of_elements += elements
    # NOTE(review): any other ``reduce`` value leaves the state untouched;
    # confirm the constructor restricts ``reduce`` to 'mean' / 'max'.
| 48 | |
| 49 | class MeasurePrinter(): |
| 50 | """Helper class for print top-k record.""" |
no outgoing calls
no test coverage detected