Represents a boolean matrix in a packed format where each element occupies a single bit of memory. `scratchpad` is either None or an all-zero array of size >= shape[-1]; we pass it along with the actual bitmatrix to avoid having to launch a separate memset kernel when we call `Bitmatrix.sum()`.
| 168 | |
@dataclass
class Bitmatrix(Tensor):
    """
    Represents a boolean matrix in a packed format where each element occupies
    a single bit of memory.

    `scratchpad` is either None or an all-zero array of size >= shape[-1]; we
    pass it along with the actual bitmatrix to avoid having to launch a
    separate memset kernel when we call `Bitmatrix.sum()`.

    The scratchpad is single-use. `sum()` hands a view of it to
    `sum_bitmatrix_rows` and then drops the reference. A later `sum()` call
    therefore allocates a fresh buffer via `clear_sums` instead of reusing one
    that may no longer be zeroed.
    """

    # Optional pre-zeroed buffer of length >= shape[-1]; consumed by sum().
    scratchpad: torch.Tensor | None = None

    # Explicit __init__: @dataclass does not replace an __init__ defined in
    # the class body, so this one (which fixes dtype=BIT) is what runs.
    def __init__(self, storage, shape, shape_max=None, scratchpad=None):
        super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
        self.scratchpad = scratchpad

    def sum(self, partials_block_size):
        """
        Reduce the bitmatrix over its rows with `sum_bitmatrix_rows`.

        Args:
            partials_block_size: forwarded unchanged to `sum_bitmatrix_rows`.

        Returns:
            Whatever `sum_bitmatrix_rows` returns. The output buffer passed to
            it is the first `n_cols` elements of the scratchpad (presumably
            one count per column -- confirm against `sum_bitmatrix_rows`).

        Notes:
            `self.shape` must have exactly two entries (rows, cols).
            Calling this consumes `self.scratchpad`. A second call does not
            raise; it allocates a new zeroed buffer via `clear_sums`.
        """
        _, n_cols = self.shape
        if self.scratchpad is None:
            # No caller-supplied zero buffer: allocate one on our device.
            self.scratchpad = clear_sums(n_cols, self.device)
        out_ret = self.scratchpad[:n_cols]
        # Drop our reference so the (now presumably non-zero) buffer is never
        # reused; the next sum() will fall into the clear_sums branch above.
        self.scratchpad = None
        return sum_bitmatrix_rows(self, out_ret, partials_block_size)
| 195 | |
| 196 | def get_layout(tensor: torch.Tensor | Tensor | None): |