hub / github.com/NVIDIA/TensorRT-LLM / compute_expt_data

Function compute_expt_data

triton_kernels/routing.py:253–275
(expt_hist, n_expts_tot, n_gates)

Source from the content-addressed store, hash-verified

def compute_expt_data(expt_hist, n_expts_tot, n_gates):

    if expt_hist is None:
        return ExptData(None, None, None, None)

    # this just computes the kernel arguments:
    token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
        expt_hist, n_expts_tot, n_gates)

    _expt_data_memset[(blocks1, )](
        expt_hist, n_expts_tot,  #
        token_offs_combined, token_offs_combined.stride(0),  #
        block_pid_map,  #
        block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK,  # optimization parameters
        num_warps=4)
    _expt_data_compute[(blocks2, )](
        expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0),  # outputs
        block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M,  # optimization parameters
        num_warps=4)

    token_offs_pad = _unpack_into_dict(token_offs_pad)
    block_pid_map = _unpack_into_dict(block_pid_map)
    return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
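
The listing only shows the kernel launches, so the meaning of the returned fields has to be inferred from their names. The following is a minimal pure-Python sketch of one plausible reading, not the library's implementation. It assumes that token_offs_raw is an exclusive prefix sum over expt_hist, that token_offs_pad counts offsets in whole row blocks for each candidate block size, and that block_pid_map lists one (expert, block) pair per block. The function reference_expt_data, the helper cdiv, and the default values for block_m_log2_start and block_m_num are hypothetical and chosen only for illustration.

```python
from itertools import accumulate


def cdiv(a, b):
    """Ceiling division for non-negative integers."""
    return (a + b - 1) // b


def reference_expt_data(expt_hist, block_m_log2_start=4, block_m_num=4):
    """Illustrative CPU reference; names and defaults are assumptions."""
    # Raw offsets: exclusive prefix sum over the per-expert token counts.
    token_offs_raw = [0] + list(accumulate(expt_hist))

    token_offs_pad = {}
    block_pid_map = {}
    # One entry per candidate block size, keyed by that size (assumed to
    # mirror what _unpack_into_dict does to the stacked outputs).
    for log2 in range(block_m_log2_start, block_m_log2_start + block_m_num):
        block_m = 2 ** log2
        # Number of row blocks each expert needs at this block size.
        n_blocks = [cdiv(n, block_m) for n in expt_hist]
        # Padded offsets, counted in blocks rather than tokens (assumption).
        token_offs_pad[block_m] = [0] + list(accumulate(n_blocks))
        # One (expert index, block index within expert) pair per block.
        block_pid_map[block_m] = [(e, b) for e, n in enumerate(n_blocks) for b in range(n)]
    return token_offs_raw, token_offs_pad, block_pid_map


if __name__ == "__main__":
    raw, pad, pid_map = reference_expt_data([5, 0, 17], block_m_num=2)
    print(raw)
    print(pad)
    print(pid_map)
```

Under these assumptions, an expert with zero tokens contributes no blocks, so it never appears in the block map, while its offsets simply repeat the previous value.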

Callers

nothing calls this directly

Calls 4

ExptData · Class · 0.85
_unpack_into_dict · Function · 0.85
stride · Method · 0.80

Tested by

no test coverage detected