MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

triton_kernels/routing.py:100–154  ·  view source on GitHub ↗
(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix)

Source from the content-addressed store, hash-verified

98
99 @staticmethod
100 def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):  # Launch the memset and compute routing kernels; returns seven values (see the return line).
101 HIST_BLOCK_M = 32  # used as partials_block_size for bitmatrix.sum, for the blocks2b grid term, and passed to the compute kernel
102 INDX_OFFS_BLOCK_M = 512  # passed to the memset kernel as BLOCK_M
103 MEMSET_BLOCK = 1024  # used for the blocks1b grid term and passed to the memset kernel
104 cdiv = triton.cdiv  # local alias
105  # Buffers below are allocated on expt_scal's device; gate_scal also takes expt_scal's dtype.
106 device = expt_scal.device
107 dtype = expt_scal.dtype
108 n_tokens_raw, _ = bitmatrix.shape  # bitmatrix.shape must have exactly two entries; only the first is used
109 n_tokens_pad, n_expts_act = expt_scal.shape  # expt_scal must be 2-D for this unpacking to succeed
110 n_gates_pad = n_tokens_pad * n_expts_act  # total number of entries in expt_scal
111
112 hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)  # sum() yields a pair here; presumably totals plus per-block partial sums -- TODO confirm
113 hist = hist[:n_expts_tot]  # keep only the first n_expts_tot entries
114 assert hist.dtype == torch.int32
115 # scratchpad
116 expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)  # uninitialized; handed to both kernels
117 combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)  # uninitialized; one allocation backing both index outputs
118 # output
119 topk_indx = combined_indx[:n_gates_pad]  # view of the first half of combined_indx
120 gate_indx = combined_indx[n_gates_pad:]  # view of the second half of combined_indx
121 gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)  # uninitialized; passed to the compute kernel in its "outputs" group
122
123 token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(  # helper defined outside this listing; returns ten values
124 hist, n_expts_tot, n_gates_pad)
125  # blocks1b / blocks2b are added to the helper's blocks1a / blocks2a to size the two launch grids below.
126 blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
127 blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
128
129 _combined_routing_memset[(blocks1a + blocks1b, )](  # first launch: 1-D grid of size blocks1a + blocks1b
130 combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist, # buffer, its length, -1 (presumably the fill value -- TODO confirm), block size, hist
131 expt_offs, hist.shape[0], n_expts_tot, partial_hist, # inputs
132 partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), # outputs
133 token_offs_combined, token_offs_combined.stride(0), #
134 blocks1a, block_pid_map, #
135 block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A, # optimization parameters
136 BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters
137 )
138
139 indx_offs = partial_hist  # alias, no copy; NOTE(review): presumably the memset launch rewrote partial_hist in place into offsets -- confirm
140
141 _combined_routing_compute[(blocks2a + blocks2b, )](  # second launch: 1-D grid of size blocks2a + blocks2b
142 topk_indx, gate_indx, gate_scal, # outputs
143 expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), # inputs
144 expt_offs, n_tokens_raw, # input shape
145 HIST_BLOCK_M, n_expts_act, # constants
146 hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
147 block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a, # etc.
148 )
149  # Stash the sizes on ctx and save gate_indx (the only tensor saved) for backward.
150 ctx.n_tokens_raw = n_tokens_raw
151 ctx.n_tokens_pad = n_tokens_pad
152 ctx.n_expts_act = n_expts_act
153 ctx.save_for_backward(gate_indx)
154 return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map  # seven outputs, matching the seven gradient slots in backward's signature
155
156 @staticmethod
157 def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):

Callers

nothing calls this directly

Calls 5

cdiv — Function · 0.85
sum — Method · 0.80
stride — Method · 0.80
empty — Method · 0.45

Tested by

no test coverage detected