| 149 | |
| 150 | |
def process_input(input_ids_list: List[torch.Tensor],
                  token_type_ids_list: List[torch.Tensor],
                  is_roberta=False,
                  padding_idx=1,
                  *,
                  device="cuda") -> Tuple[torch.Tensor, torch.Tensor,
                                          torch.Tensor, torch.Tensor,
                                          torch.Tensor]:
    """Pack a batch of variable-length samples into flat int32 tensors.

    Args:
        input_ids_list: One tensor of token ids per sample.
        token_type_ids_list: One tensor of token type ids per sample; sample
            ``i`` must have the same length as ``input_ids_list[i]``.
        is_roberta: When true, position ids are shifted so that they start
            at ``padding_idx + 1`` instead of ``0``.
        padding_idx: Offset base used for the shift; only read when
            ``is_roberta`` is true.
        device: Device the returned tensors are placed on. The default
            ``"cuda"`` targets the current CUDA device, as ``.cuda()`` does.

    Returns:
        A tuple ``(input_ids, input_lengths, token_type_ids, position_ids,
        max_input_length)`` of int32 tensors on ``device``:

        * ``input_ids``, ``token_type_ids``, ``position_ids``: all samples
          concatenated, shape ``[num_tokens]``.
        * ``input_lengths``: per-sample lengths, shape ``[batch_size]``.
        * ``max_input_length``: an *uninitialized* tensor whose length equals
          the longest sample; presumably consumed for its shape only
          (TODO confirm against the caller).

    Raises:
        ValueError: If the two lists differ in length, or if any sample's
            ``input_ids`` and ``token_type_ids`` differ in length.
    """
    # Validate with explicit raises rather than `assert`, which is stripped
    # under `python -O`. Checking the list lengths first also catches surplus
    # token_type_ids entries that would otherwise be concatenated silently.
    if len(input_ids_list) != len(token_type_ids_list):
        raise ValueError(
            f"len(input_ids_list)={len(input_ids_list)}, "
            f"len(token_type_ids_list)={len(token_type_ids_list)}, not equal")

    # Loop-invariant: first position id of every sample.
    position_start = padding_idx + 1 if is_roberta else 0

    sample_lengths = []
    position_ids_list = []
    for i, (input_ids, token_type_ids) in enumerate(
            zip(input_ids_list, token_type_ids_list)):
        input_len = len(input_ids)
        if input_len != len(token_type_ids):
            raise ValueError(
                f"sample {i}: len(input_ids)={input_len}, "
                f"len(token_type_ids)={len(token_type_ids)}, not equal")
        sample_lengths.append(input_len)
        position_ids_list.append(
            torch.arange(position_start,
                         position_start + input_len,
                         dtype=torch.int32))

    longest = max(sample_lengths, default=0)

    # [num_tokens]
    input_ids = torch.concat(input_ids_list).int().to(device)
    token_type_ids = torch.concat(token_type_ids_list).int().to(device)
    # Already int32 from torch.arange, so no extra cast is needed.
    position_ids = torch.concat(position_ids_list).to(device)

    # [batch_size]
    input_lengths = torch.tensor(sample_lengths,
                                 dtype=torch.int32,
                                 device=device)
    # Contents are left uninitialized on purpose; the length is the payload.
    max_input_length = torch.empty((longest, ),
                                   dtype=torch.int32,
                                   device=device)
    return input_ids, input_lengths, token_type_ids, position_ids, max_input_length
| 178 | |
| 179 | |
| 180 | def intermediate_check(tllm_inter: Dict, hf_ref: Tuple[torch.Tensor], attn_mask, |