(seqA, seqB, max_seq_len, tokenizer)
| 81 | |
| 82 | |
def encode_sequence(seqA, seqB, max_seq_len, tokenizer):
    """Build padded BERT-style model inputs for a token-sequence pair.

    The pair is laid out as ``[CLS] seqA [SEP] seqB [SEP]``, converted to ids
    with ``tokenizer.convert_tokens_to_ids``, and right-padded up to
    ``max_seq_len``.

    Args:
        seqA: list of tokens for the first segment (not mutated).
        seqB: list of tokens for the second segment (not mutated).
        max_seq_len: length to pad the outputs to. No truncation is done
            here: if the combined sequence is already longer, the outputs
            are returned unpadded and longer than ``max_seq_len``
            (NOTE(review): callers presumably truncate beforehand, e.g. via
            ``truncate_input_sequence`` — confirm).
        tokenizer: object exposing ``convert_tokens_to_ids(tokens)``.

    Returns:
        Tuple ``(input_ids, input_mask, sequence_ids)``, each converted with
        ``map_to_torch``. ``input_mask`` is 1 for real tokens and 0 for
        padding. ``sequence_ids`` is 0 for the ``[CLS] seqA [SEP]`` part,
        1 for the ``seqB [SEP]`` part, and 0 for padding.
    """
    seqA = ["[CLS]"] + seqA + ["[SEP]"]
    seqB = seqB + ["[SEP]"]

    input_tokens = seqA + seqB
    # Copy into a list so it can be extended even if the tokenizer
    # returns some other sequence type.
    input_ids = list(tokenizer.convert_tokens_to_ids(input_tokens))
    sequence_ids = [0] * len(seqA) + [1] * len(seqB)
    input_mask = [1] * len(input_ids)

    pad_len = max(0, max_seq_len - len(input_ids))
    input_ids.extend([PAD] * pad_len)
    # Padded positions must be masked out (0) and assigned to segment 0,
    # independent of whatever token id PAD happens to be.
    sequence_ids.extend([0] * pad_len)
    input_mask.extend([0] * pad_len)

    return (map_to_torch(input_ids), map_to_torch(input_mask), map_to_torch(sequence_ids))
| 98 | |
| 99 | |
| 100 | def truncate_input_sequence(tokens_a, tokens_b, max_num_tokens): |
no test coverage detected