(self, partials_block_size)
| 184 | self.scratchpad = scratchpad |
| 185 | |
def sum(self, partials_block_size):
    """Reduce this bitmatrix via ``sum_bitmatrix_rows`` using a one-shot buffer.

    The output buffer is the first ``n_cols`` entries of ``self.scratchpad``,
    where ``n_cols`` is the second entry of ``self.shape``. If no scratchpad
    is attached, one is created with ``clear_sums(n_cols, self.device)``.
    The scratchpad reference is dropped from the instance before the
    reduction runs, so the same buffer is never handed out twice.

    Args:
        partials_block_size: Forwarded unchanged to ``sum_bitmatrix_rows``.

    Returns:
        Whatever ``sum_bitmatrix_rows(self, out, partials_block_size)``
        returns (presumably per-column sums -- confirm against that helper).
    """
    # Two-element unpacking is intentional: it requires a 2-D shape.
    _n_rows, n_cols = self.shape
    device = self.device

    buffer = self.scratchpad
    if buffer is None:
        # No pre-allocated scratchpad: create one and attach it.
        buffer = self.scratchpad = clear_sums(n_cols, device)

    col_sums = buffer[:n_cols]

    # Detach the scratchpad so this buffer is not reused by a later call.
    # NOTE(review): as written, a second call does not raise -- it takes the
    # `is None` branch above and allocates a fresh buffer. Confirm intent.
    self.scratchpad = None

    return sum_bitmatrix_rows(self, col_sums, partials_block_size)
| 194 | |
| 195 | |
| 196 | def get_layout(tensor: torch.Tensor | Tensor | None): |