(expt_hist, n_expts_tot, n_gates)
| 251 | |
| 252 | |
def compute_expt_data(expt_hist, n_expts_tot, n_gates):
    """Assemble an ``ExptData`` record from an expert histogram.

    When ``expt_hist`` is ``None`` no kernel is launched and an ``ExptData``
    holding four ``None`` fields is returned.

    Otherwise ``_compute_expt_data_internal`` supplies every launch argument
    (buffers, grid sizes and block-size constants), after which two kernels
    run in order: ``_expt_data_memset`` followed by ``_expt_data_compute``.
    The padded-offset and block-pid-map buffers are finally converted with
    ``_unpack_into_dict`` before being stored in the result.

    Args:
        expt_hist: Histogram forwarded to the helper and to both kernels;
            presumably per-expert token counts -- confirm against callers.
        n_expts_tot: Forwarded to the helper and to the memset kernel.
        n_gates: Forwarded to the helper only.

    Returns:
        ``ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)``
        where the last two fields are the ``_unpack_into_dict`` results.
    """
    # Nothing to process: return the all-None placeholder without launching.
    if expt_hist is None:
        return ExptData(None, None, None, None)

    # This only computes the kernel arguments; no kernel runs in the helper.
    (
        offs_combined,
        offs_raw,
        offs_pad,
        pid_map,
        n_memset_programs,
        n_compute_programs,
        memset_block,
        hist2_block_m,
        log2_start,
        n_block_sizes,
    ) = _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates)

    # First launch: the memset kernel must run before the compute kernel.
    memset_grid = (n_memset_programs, )
    _expt_data_memset[memset_grid](
        expt_hist,
        n_expts_tot,
        offs_combined,
        offs_combined.stride(0),
        pid_map,
        log2_start,
        # optimization parameters
        SIZES=n_block_sizes,
        BLOCK=memset_block,
        num_warps=4,
    )

    # Second launch: offs_pad and pid_map are the kernel's outputs.
    compute_grid = (n_compute_programs, )
    _expt_data_compute[compute_grid](
        expt_hist,
        offs_pad,
        offs_pad.stride(0),
        pid_map,
        pid_map.stride(0),
        log2_start,
        # optimization parameters
        SIZES=n_block_sizes,
        BLOCK=hist2_block_m,
        num_warps=4,
    )

    # Arguments evaluate left to right, so offs_pad is unpacked before pid_map.
    return ExptData(
        expt_hist,
        offs_raw,
        _unpack_into_dict(offs_pad),
        _unpack_into_dict(pid_map),
    )
| 276 | |
| 277 | |
| 278 | # -------------------------- |
nothing calls this directly
no test coverage detected