(text_a_ids, text_b_ids, answer_ids, max_seq_length, tokenizer, args=None, add_cls=True,
add_sep=False, add_piece=False, add_eos=True, mask_id=None)
| 142 | |
| 143 | |
| 144 | def build_input_from_ids(text_a_ids, text_b_ids, answer_ids, max_seq_length, tokenizer, args=None, add_cls=True, |
| 145 | add_sep=False, add_piece=False, add_eos=True, mask_id=None): |
| 146 | if mask_id is None: |
| 147 | mask_id = tokenizer.get_command('MASK').Id |
| 148 | eos_id = tokenizer.get_command('eos').Id |
| 149 | cls_id = tokenizer.get_command('ENC').Id |
| 150 | sep_id = tokenizer.get_command('sep').Id |
| 151 | ids = [] |
| 152 | types = [] |
| 153 | paddings = [] |
| 154 | # CLS |
| 155 | if add_cls: |
| 156 | ids.append(cls_id) |
| 157 | types.append(0) |
| 158 | paddings.append(1) |
| 159 | # A |
| 160 | len_text_a = len(text_a_ids) |
| 161 | ids.extend(text_a_ids) |
| 162 | types.extend([0] * len_text_a) |
| 163 | paddings.extend([1] * len_text_a) |
| 164 | # B |
| 165 | if text_b_ids is not None: |
| 166 | # SEP |
| 167 | if add_sep: |
| 168 | ids.append(sep_id) |
| 169 | types.append(0) |
| 170 | paddings.append(1) |
| 171 | len_text_b = len(text_b_ids) |
| 172 | ids.extend(text_b_ids) |
| 173 | types.extend([1] * len_text_b) |
| 174 | paddings.extend([1] * len_text_b) |
| 175 | eos_length = 1 if add_eos else 0 |
| 176 | # Cap the size. |
| 177 | if len(ids) >= max_seq_length - eos_length: |
| 178 | max_seq_length_m1 = max_seq_length - 1 |
| 179 | ids = ids[0:max_seq_length_m1] |
| 180 | types = types[0:max_seq_length_m1] |
| 181 | paddings = paddings[0:max_seq_length_m1] |
| 182 | end_type = 0 if text_b_ids is None else 1 |
| 183 | if add_eos: |
| 184 | ids.append(eos_id) |
| 185 | types.append(end_type) |
| 186 | paddings.append(1) |
| 187 | sep = len(ids) |
| 188 | target_ids = [0] * len(ids) |
| 189 | loss_masks = [0] * len(ids) |
| 190 | position_ids = list(range(len(ids))) |
| 191 | block_position_ids = [0] * len(ids) |
| 192 | # Piece |
| 193 | if add_piece or answer_ids is not None: |
| 194 | sop_id = tokenizer.get_command('sop').Id |
| 195 | assert mask_id in ids |
| 196 | mask_position = len(ids) - ids[-1::-1].index( |
| 197 | mask_id) - 1 if not args.sentinel_token else args.max_position_embeddings |
| 198 | ids.append(sop_id) |
| 199 | types.append(end_type) |
| 200 | paddings.append(1) |
| 201 | position_ids.append(mask_position) |
no test coverage detected