(self, numel: int, numel_padded: int, dtype: torch.dtype)
| 28 | """ |
| 29 | |
def __init__(
    self,
    numel: int,
    numel_padded: int,
    dtype: torch.dtype,
    *,
    device: "torch.device | str" = "cuda",
):
    """Allocate a zero-filled 1-D tensor of ``numel_padded`` elements.

    Args:
        numel: Element count recorded on ``self.numel``. It is stored as-is
            and does not affect the allocation size.
        numel_padded: Number of elements actually allocated for ``self.data``.
        dtype: Element dtype of the allocated tensor.
        device: Device to allocate on (keyword-only). Defaults to ``"cuda"``,
            which is the previous hard-coded behaviour (PyTorch resolves a bare
            ``"cuda"`` to the current CUDA device). Pass ``"cpu"`` or an
            explicit ``torch.device`` to allocate elsewhere.
    """
    self.numel = numel
    self.numel_padded = numel_padded
    self.dtype = dtype
    # The storage is sized by numel_padded, not numel.
    self.data = torch.zeros(
        self.numel_padded,
        dtype=self.dtype,
        device=device,
        requires_grad=False,
    )
| 35 | |
| 36 | def zero(self): |
| 37 | """Reset the buffer to zero.""" |