A memory buffer is a contiguous torch tensor that may combine multiple tensors that share its underlying memory. It must have a unique type to support this behavior.
| 22 | |
| 23 | |
class MemoryBuffer:
    """
    A memory buffer is a contiguous 1-D torch tensor that may back multiple tensors, each of which is a
    view sharing the buffer's underlying memory. All views therefore have the buffer's single dtype.

    Attributes:
        numel: Number of usable elements; `get` refuses any view reaching past this index.
        numel_padded: Number of elements actually allocated for `data`.
        dtype: Element type of `data` and of every view returned by `get`.
        data: The zero-initialised 1-D backing tensor holding `numel_padded` elements.
    """

    def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, device='cuda'):
        """Allocate the zero-filled backing tensor.

        Args:
            numel: Number of usable elements in the buffer.
            numel_padded: Number of elements to allocate.
            dtype: Element type of the backing tensor.
            device: Device on which to allocate the backing tensor (anything `torch.zeros`
                accepts as `device`). Defaults to 'cuda', the previously hard-coded value.
        """
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=device, requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`.

        Args:
            shape: Shape of the requested view; must provide `.numel()` (e.g. `torch.Size`).
            start_index: Offset, in elements, of the first element of the view within `data`.

        Returns:
            A view of `data[start_index:start_index + shape.numel()]` reshaped to `shape`.
            It shares memory with the buffer, so writes through it modify `data`.

        Raises:
            AssertionError: If `start_index` is negative or the view would end past `numel`.
        """
        end_index = start_index + shape.numel()
        # A negative start would be interpreted by slicing as an offset from the end of the
        # padded storage and could silently alias the wrong region, so reject it as well.
        if start_index < 0 or end_index > self.numel:
            # Raised explicitly (rather than via `assert`) so the check survives `python -O`;
            # AssertionError is kept so callers see the same exception type as before.
            raise AssertionError('requested tensor is out of the buffer range.')
        return self.data[start_index:end_index].view(shape)
| 49 | |
| 50 | |
| 51 | def calc_padded_numel(shape: torch.Size, dtype: torch.dtype): |
no outgoing calls
no test coverage detected