(self, true, pred, loss, lr, time_used, params, **kwargs)
| 140 | return time_per_epoch * (self._epoch_total - epoch_current) |
| 141 | |
def update_stats(self, true, pred, loss, lr, time_used, params, **kwargs):
    """Accumulate the statistics of one iteration into this tracker.

    Args:
        true: Batch of targets. Its ``shape[0]`` is used as the batch size.
        pred: Batch of predictions. Must have the same ``shape[0]`` as ``true``.
        loss: Loss for the batch. It is multiplied by the batch size before
            being added to the running total (presumably a per-sample mean;
            confirm with the caller).
        lr: Current learning rate. Only the most recent value is kept.
        time_used: Time spent on this iteration. Added to both the running
            (``_time_used``) and the overall (``_time_total``) counters.
        params: Only the most recent value is kept.
        **kwargs: Extra named stats. Each is multiplied by the batch size and
            summed per key across calls.

    Raises:
        ValueError: If ``true`` and ``pred`` disagree on the batch dimension.
            No state is modified in that case.
    """
    batch_size = true.shape[0]
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if pred.shape[0] != batch_size:
        raise ValueError(
            f"true and pred batch sizes differ: {batch_size} vs {pred.shape[0]}"
        )

    self._iter += 1
    self._true.append(true)
    self._pred.append(pred)
    self._size_current += batch_size
    # Batch-size weighting; presumably divided by _size_current elsewhere to
    # get a per-sample average (NOTE(review): confirm in the reader methods).
    self._loss += loss * batch_size
    self._lr = lr
    self._params = params
    self._time_used += time_used
    self._time_total += time_used

    for key, val in kwargs.items():
        weighted = val * batch_size
        if key in self._custom_stats:
            # In-place add, matching the original accumulation semantics.
            self._custom_stats[key] += weighted
        else:
            self._custom_stats[key] = weighted
| 159 | |
def write_iter(self):
    """Placeholder for per-iteration output.

    This implementation performs no work and always raises
    ``NotImplementedError``, so a subclass is expected to override it.
    """
    raise NotImplementedError()
no outgoing calls
no test coverage detected