| 34 | |
| 35 | |
class MomentumBuffer:
    """Accumulate a momentum-weighted running sum of update values.

    Every call to :meth:`update` replaces ``running_average`` with
    ``update_value + momentum * running_average``. The incoming value is
    always added with weight 1, so this is an unnormalized accumulation
    rather than a convex exponential moving average.

    Attributes are plain and public:
        momentum: scalar multiplier applied to the previous running value.
        running_average: current accumulated value; the integer ``0`` until
            the first update, afterwards whatever the addition produces.
    """

    def __init__(self, momentum: float):
        """Store the momentum factor and start from an empty (zero) state."""
        # A plain integer zero is used as the seed: the first update then
        # yields ``update_value + momentum * 0``.
        self.running_average = 0
        self.momentum = momentum

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into the running value in place.

        Args:
            update_value: the new contribution to add on top of the
                momentum-scaled previous state.

        Returns:
            None. The result is stored on ``self.running_average``.
        """
        # Scale the previous state first, then add the fresh value to it.
        self.running_average = update_value + self.momentum * self.running_average
| 44 | |
| 45 | |
| 46 | def project(v0: torch.Tensor, v1: torch.Tensor): |
no outgoing calls
no test coverage detected