x: [sum(num_gen_tokens), V/H]; num_gen_tokens: [gen_bs]; gen_unpack_indxs: [gen_bs, max(num_gen_tokens)]. Returns: [gen_bs, max_gen_tokens, V/H], where max_gen_tokens = max(num_gen_tokens).
(x: Tensor, num_gen_tokens: Tensor,
gen_unpack_indxs: Tensor,
max_gen_tokens: Tensor)
| 632 | |
| 633 | |
def _unpack_gen_data(x: Tensor, num_gen_tokens: Tensor,
                     gen_unpack_indxs: Tensor,
                     max_gen_tokens: Tensor) -> Tensor:
    """
    Gather the packed rows of x into a [gen_bs, max_gen_tokens, V/H] layout.

    x: [sum(num_gen_tokens), V/H]
    num_gen_tokens: [gen_bs]
    gen_unpack_indxs: [gen_bs, max(num_gen_tokens)]
        Row indices into x. The final reshape needs exactly
        gen_bs * max_gen_tokens gathered rows, so the leading extent must
        match num_gen_tokens.
    max_gen_tokens: concatenated as-is into the output shape, so presumably
        a 1-element integer tensor holding max(num_gen_tokens) -- TODO
        confirm with callers.
    Returns:
        [gen_bs, max_gen_tokens, V/H] where max_gen_tokens = max(num_gen_tokens)
    """
    # Flatten the index table so a single gather along dim 0 fetches every
    # requested row of x.
    flat_indices = view(gen_unpack_indxs, [-1])
    gathered = index_select(x, dim=0, index=flat_indices)

    # Build the target shape from runtime extents: the batch size comes from
    # num_gen_tokens and the trailing (V/H) extent from x itself.
    gen_bs = shape(num_gen_tokens, 0, INT_DTYPE_STR)
    feature_dim = shape(x, -1, INT_DTYPE_STR)
    target_shape = concat([gen_bs, max_gen_tokens, feature_dim])

    # NOTE(review): zero_is_placeholder=False presumably makes a 0 extent in
    # target_shape literal instead of "copy from input" -- confirm against
    # the view() helper.
    return gathered.view(target_shape, zero_is_placeholder=False)
| 650 | |
| 651 | |
| 652 | def _process_logits_and_hidden_states( |
no test coverage detected