(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length)
| 13 | |
| 14 | |
def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length):
    """Tokenize a JSON corpus and write it out as ``num_pieces`` id files.

    The corpus is split into ``num_pieces`` contiguous pieces; whatever the
    integer division leaves over is appended to the last piece. Within each
    piece, every text longer than ``min_length`` characters is tokenized and
    emitted as ``[MASK] <token ids> [CLS]`` so the two markers delimit
    articles.

    Args:
        data_path: Path of a UTF-8 JSON file that must decode to a list of
            strings (each element is treated as one article).
        tokenized_data_path: Output directory, created (with parents) if
            missing. Output file names are appended to it by plain string
            concatenation, so it should end with a path separator
            (e.g. ``'data/tokenized/'``).
        num_pieces: Number of output files, named ``tokenized_train_0.txt``
            through ``tokenized_train_{num_pieces - 1}.txt``. When it is
            less than 1, no file is written.
        full_tokenizer: Object providing ``tokenize(str)`` and
            ``convert_tokens_to_ids``; the latter is called both with a token
            list and with a single token string (``'[MASK]'``, ``'[CLS]'``).
        min_length: Texts whose character length (measured after the
            ``[SEP]`` substitution) is ``<= min_length`` are skipped.
    """
    with open(data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
    # Encode in-text newlines as [SEP] so paragraph boundaries survive
    # tokenization.
    lines = [line.replace('\n', ' [SEP] ') for line in lines]
    all_len = len(lines)

    # makedirs also creates missing parent directories and tolerates the
    # directory appearing concurrently, unlike an exists()/mkdir() pair.
    os.makedirs(tokenized_data_path, exist_ok=True)

    # Special-token ids are identical for every article; look them up once.
    mask_id = full_tokenizer.convert_tokens_to_ids('[MASK]')
    cls_id = full_tokenizer.convert_tokens_to_ids('[CLS]')

    # Guarded so num_pieces <= 0 simply writes nothing instead of dividing
    # by zero (the loop below does not run in that case).
    piece_len = all_len // num_pieces if num_pieces > 0 else 0

    for i in tqdm(range(num_pieces)):
        start = piece_len * i
        # The last piece also absorbs the tail left over by the division.
        end = all_len if i == num_pieces - 1 else piece_len * (i + 1)

        full_line = []
        for line in lines[start:end]:
            if len(line) <= min_length:  # only keep texts longer than min_length
                continue
            tokens = full_tokenizer.tokenize(line)
            full_line.append(mask_id)  # [MASK] marks the start of an article
            full_line.extend(full_tokenizer.convert_tokens_to_ids(tokens))
            full_line.append(cls_id)  # [CLS] marks the end of an article

        # Plain concatenation (not os.path.join) is kept on purpose: readers
        # elsewhere presumably rebuild these names the same way — verify
        # before changing.
        out_path = tokenized_data_path + 'tokenized_train_{}.txt'.format(i)
        with open(out_path, 'w', encoding='utf8') as f:
            # Same on-disk format as before: every id followed by one space.
            f.write(''.join(str(token_id) + ' ' for token_id in full_line))
    print('finish')
| 39 | |
| 40 | |
| 41 | def main(): |
no test coverage detected