MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / topk_forward

Function topk_forward

triton_kernels/topk.py:13–58  ·  view source on GitHub ↗
(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None)

Source from the content-addressed store, hash-verified

11
12
def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
    """Run the ``_topk_forward`` kernel on the rows of ``x`` and return its outputs.

    The kernel fills a top-k values buffer, a top-k indices buffer and a packed
    ``uint32`` bitmatrix (32 columns per word). Presumably the bitmatrix holds one
    bit per selected (row, column); confirm against the kernel.

    Args:
        x: 2-D input. If it is not already a ``Tensor`` wrapper it is wrapped here:
            ``shape[0]`` comes from ``n_rows`` when given (else ``x.shape[0]``), and
            ``shape_max`` records the allocated ``x.shape``.
        k: number of entries selected per row (passed to the kernel as ``N_EXPTS_ACT``).
        apply_softmax: forwarded to the kernel as ``APPLY_SOFTMAX``.
        dim: must be 1; only row-wise selection is supported.
        return_bitmatrix: must be True; the bitmatrix is always built and returned.
        y_indx: optional caller-supplied index buffer. When given, it is passed to the
            kernel with ``use_provided_indx=True`` and returned as-is. Otherwise an
            ``int16`` buffer of shape ``(n_rows_max, k)`` is allocated.
        n_rows: logical row count. It is only used when ``x`` is not already a
            ``Tensor``; for a ``Tensor`` input, ``x.shape`` wins and this is ignored.

    Returns:
        ``(y_vals, y_indx, bitmatrix)``:

        - ``y_vals`` is ``(n_rows_max, k)`` with ``x.dtype``.
        - ``y_indx`` is the index buffer described above.
        - ``bitmatrix`` is a ``Bitmatrix`` with ``shape=[n_rows, n_cols_words * 32]``
          and ``shape_max=[n_rows_max, None]``, carrying the scratchpad buffer.

    Raises:
        AssertionError: if ``x`` is not 2-D, ``x.shape_max[-1] >= 32768``,
            ``dim != 1``, or ``return_bitmatrix`` is False.
    """
    if not isinstance(x, Tensor):
        # Check the rank on the raw input. The wrapper below is built from
        # x.shape[0] and x.shape[1] only, so a check on the wrapped shape would
        # not see extra dimensions.
        assert len(x.shape) == 2, f"topk_forward expects a 2-D input, got shape {tuple(x.shape)}"
        x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
        x_shape_max = [x.shape[0], x.shape[1]]
        x = Tensor(x, shape=x_shape, shape_max=x_shape_max)

    def cdiv(a, b):
        # Ceiling division for non-negative integers.
        return (a + b - 1) // b

    BLOCK_M = 32
    BLOCK_N = 32  # n_cols_pad (a multiple of BLOCK_N) is split into 32-bit words below
    BLOCK_S = 128
    assert len(x.shape) == 2, f"topk_forward expects a 2-D input, got shape {x.shape}"
    # Presumably keeps column indices within the int16 range of the default y_indx buffer.
    assert x.shape_max[-1] < 32768, f"topk_forward supports fewer than 32768 columns, got {x.shape_max[-1]}"
    assert dim == 1, f"topk_forward only supports dim=1, got dim={dim}"
    assert return_bitmatrix, "topk_forward always builds a bitmatrix; return_bitmatrix must be True"

    n_rows, n_cols = x.shape
    n_rows_max, _ = x.shape_max
    dev = x.device

    # Top-k output buffers, sized for the allocated (maximum) row count.
    # Both y_vals and y_indx are returned to the caller.
    y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
    use_provided_indx = y_indx is not None
    if not use_provided_indx:
        y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)

    # Bitmatrix in transposed memory layout. The storage is allocated as
    # (words, padded rows) and then transposed. The resulting (rows, words) view
    # has stride 1 along rows, so each word column is contiguous.
    n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
    n_cols_words = n_cols_pad // 32
    bitmatrix = torch.empty((n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev)
    bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]

    # Scratch buffer that the kernel is asked to memset to zero. It is attached
    # to the returned Bitmatrix.
    s_blocks = cdiv(n_cols, BLOCK_S)
    s_cols = s_blocks * BLOCK_S
    scratchpad = torch.empty((s_cols, ), dtype=torch.int32, device=dev)

    # Grid size: the larger of the row-block count and the scratchpad-block count.
    pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
    _topk_forward[(pids, )](
        x, x.stride(0),  # inputs
        y_vals, y_indx, y_vals.stride(0), use_provided_indx,  # output [topk]
        bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1),  # output [bitmatrix]
        n_rows, n_cols,  # shapes
        scratchpad, BLOCK_S, s_blocks,  # thing to memset to zero
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,  # tunable parameter
        APPLY_SOFTMAX=apply_softmax, N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k,  # constants
    )

    bitmatrix_shape = [n_rows, n_cols_words * 32]
    bitmatrix_shape_max = [n_rows_max, None]
    bitmatrix = Bitmatrix(bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=scratchpad)
    return y_vals, y_indx, bitmatrix
59
60
61def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):

Callers 1

forward · Method · 0.85

Calls 7

stride · Method · 0.95
Tensor · Class · 0.90
Bitmatrix · Class · 0.90
cdiv · Function · 0.85
max · Function · 0.85
transpose · Method · 0.80
empty · Method · 0.45

Tested by

no test coverage detected