| 11 | |
| 12 | |
def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
    """Allocate the output buffers and launch the ``_topk_forward`` kernel on ``x``.

    Args:
        x: Either a ``Tensor`` wrapper or a raw 2-D array-like exposing
            ``shape``, ``dtype``, ``device`` and ``stride``. A raw input is
            wrapped in ``Tensor`` with ``shape_max`` taken from its own shape.
        k: Width of the per-row output buffers; forwarded to the kernel as
            ``N_EXPTS_ACT``.
        apply_softmax: Forwarded to the kernel as ``APPLY_SOFTMAX``.
        dim: Must be ``1``; no other value is supported.
        return_bitmatrix: Must be truthy; the bitmatrix is always produced.
        y_indx: Optional caller-provided index buffer. When given it is handed
            to the kernel as-is (its shape and dtype are not checked here) and
            the kernel is told so through the ``use_provided_indx`` flag.
        n_rows: Logical row count, consulted only when ``x`` is not already a
            ``Tensor``. Otherwise the value is ignored and replaced by
            ``x.shape[0]``.

    Returns:
        A tuple ``(y_vals, y_indx, bitmatrix)`` where ``y_vals`` has shape
        ``(n_rows_max, k)`` and the dtype of ``x``, ``y_indx`` is either the
        provided buffer or a fresh ``int16`` buffer of the same shape, and
        ``bitmatrix`` is a ``Bitmatrix`` wrapping a ``uint32`` view together
        with the scratchpad buffer.

    Raises:
        AssertionError: If ``x`` is not 2-D, if its last maximum dimension is
            32768 or more, if ``dim != 1`` or if ``return_bitmatrix`` is falsy.
    """

    def _ceil_div(numer, denom):
        # Integer division rounded towards +infinity.
        return (numer + denom - 1) // denom

    # Raw inputs are promoted to the project `Tensor` wrapper so that the
    # logical shape (possibly overridden by `n_rows`) and the allocated shape
    # are tracked separately.
    if not isinstance(x, Tensor):
        logical_rows = x.shape[0] if n_rows is None else n_rows
        x = Tensor(x, shape=[logical_rows, x.shape[1]], shape_max=[x.shape[0], x.shape[1]])

    # Launch tile sizes.
    block_m = 32
    block_n = 32
    block_s = 128

    # Supported configuration. The column bound is presumably tied to the
    # int16 index buffer allocated below -- TODO confirm against the kernel.
    assert len(x.shape) == 2
    assert x.shape_max[-1] < 32768
    assert dim == 1
    assert return_bitmatrix

    n_rows, n_cols = x.shape
    n_rows_max, _ = x.shape_max
    device = x.device

    # Per-row output buffers, sized by the maximum row count.
    out_shape = (n_rows_max, k)
    y_vals = torch.empty(out_shape, dtype=x.dtype, device=device)
    use_provided_indx = y_indx is not None
    if not use_provided_indx:
        y_indx = torch.empty(out_shape, dtype=torch.int16, device=device)

    # Bitmatrix storage is allocated as (words, padded rows) and then viewed
    # transposed, which gives a (rows, words) tensor in a transposed memory
    # layout; the view is finally trimmed to `n_rows_max` rows.
    padded_cols = _ceil_div(n_cols, block_n) * block_n
    words_per_row = padded_cols // 32
    padded_rows = _ceil_div(n_rows_max, 32) * 32
    bitmatrix_storage = torch.empty((words_per_row, padded_rows), dtype=torch.uint32, device=device)
    bitmatrix_view = torch.transpose(bitmatrix_storage, 0, 1)[:n_rows_max]

    # Scratchpad length is the column count rounded up to whole `block_s` chunks.
    n_scratch_blocks = _ceil_div(n_cols, block_s)
    scratchpad = torch.empty((n_scratch_blocks * block_s, ), dtype=torch.int32, device=device)

    # One program per row tile, but never fewer programs than scratchpad blocks.
    grid = (max(_ceil_div(n_rows_max, block_m), n_scratch_blocks), )
    _topk_forward[grid](
        # input
        x,
        x.stride(0),
        # top-k outputs
        y_vals,
        y_indx,
        y_vals.stride(0),
        use_provided_indx,
        # bitmatrix output
        bitmatrix_view,
        bitmatrix_view.stride(0),
        bitmatrix_view.stride(1),
        # logical shape
        n_rows,
        n_cols,
        # buffer the original annotation marks as "to memset to zero"
        scratchpad,
        block_s,
        n_scratch_blocks,
        # tunable tile sizes
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        # compile-time constants
        APPLY_SOFTMAX=apply_softmax,
        N_EXPTS_PAD=padded_cols,
        N_EXPTS_ACT=k,
    )

    return y_vals, y_indx, Bitmatrix(
        bitmatrix_view,
        shape=[n_rows, words_per_row * 32],
        shape_max=[n_rows_max, None],
        scratchpad=scratchpad,
    )
| 59 | |
| 60 | |
| 61 | def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax): |