Function `_tensor_copy_3dim(in_ptr, in_stride_0, in_stride_1, in_stride_2, out_ptr, out_stride_0, out_stride_1, out_stride_2, head_num, head_dim, total_len, BLOCK_N: tl.constexpr)`
Source from the content-addressed store, hash-verified
| 54 | |
@triton.jit
def _tensor_copy_3dim(
    in_ptr,
    in_stride_0,
    in_stride_1,
    in_stride_2,
    out_ptr,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    head_num,
    head_dim,
    total_len,
    BLOCK_N: tl.constexpr,
):
    """Copy ``in[i, h, :head_dim]`` into ``out[i, h, :head_dim]``.

    Runs over every ``i < total_len`` and every ``h < head_num``. Each
    program starts at ``i = program_id(0)`` and advances by
    ``num_programs(0)``. For every row it loops over all heads and moves
    one ``BLOCK_N``-wide vector of the last dimension, masked to
    ``head_dim``.

    Only offsets ``0 .. BLOCK_N - 1`` of the last dimension are ever
    touched. The launcher therefore has to pick ``BLOCK_N >= head_dim``,
    otherwise the tail of each head is not copied.

    All three strides of both tensors are honored, so the last dimension
    does not have to be contiguous. For a unit last-dim stride the
    addresses are the same as a plain ``+ offs_d``.
    """
    pid = tl.program_id(0)
    num_pids = tl.num_programs(0)

    offs_d = tl.arange(0, BLOCK_N)
    # The mask and last-dim offsets do not depend on the row or head; compute once.
    mask_d = offs_d < head_dim
    in_offs_d = offs_d * in_stride_2
    out_offs_d = offs_d * out_stride_2

    # Python's built-in range() takes no keyword arguments, so the stride
    # must be passed positionally to be honored.
    for cur_index in range(pid, total_len, num_pids):
        in_row = in_ptr + in_stride_0 * cur_index
        out_row = out_ptr + out_stride_0 * cur_index
        for cur_head in tl.range(head_num, num_stages=3):
            vals = tl.load(in_row + in_stride_1 * cur_head + in_offs_d, mask=mask_d, other=0)
            tl.store(out_row + out_stride_1 * cur_head + out_offs_d, vals, mask=mask_d)
| 83 | |
| 84 | |
| 85 | @torch.no_grad() |
Callers: nothing calls this directly.
Tested by: no test coverage detected.