hub / github.com/Morizeyao/GPT2-Chinese / build_files

Function build_files

train.py:15–38
(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length)

import json
import os

from tqdm import tqdm


def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length):
    with open(data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        # Use [SEP] for line breaks; between paragraphs, SEP marks the end of a paragraph.
        lines = [line.replace('\n', ' [SEP] ') for line in lines]
    all_len = len(lines)
    os.makedirs(tokenized_data_path, exist_ok=True)
    for i in tqdm(range(num_pieces)):
        sublines = lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)]
        if i == num_pieces - 1:
            # Append the leftover tail examples to the last piece.
            sublines.extend(lines[all_len // num_pieces * (i + 1):])
        # Only keep sentences longer than min_length (measured in characters).
        sublines = [full_tokenizer.tokenize(line) for line in sublines if
                    len(line) > min_length]
        sublines = [full_tokenizer.convert_tokens_to_ids(line) for line in sublines]
        full_line = []
        for subline in sublines:
            # Add MASK at the start of an article to mark its beginning.
            full_line.append(full_tokenizer.convert_tokens_to_ids('[MASK]'))
            full_line.extend(subline)
            # Add CLS between articles to mark the end of an article.
            full_line.append(full_tokenizer.convert_tokens_to_ids('[CLS]'))
        # os.path.join works whether or not tokenized_data_path ends with a separator.
        piece_path = os.path.join(tokenized_data_path, 'tokenized_train_{}.txt'.format(i))
        with open(piece_path, 'w') as f:
            for token_id in full_line:
                f.write(str(token_id) + ' ')
    print('finish')
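
The two core steps of build_files, splitting the corpus into pieces and wrapping each article in start and end markers, can be looked at in isolation. The sketch below is not part of train.py: the helper names `split_into_pieces` and `frame_documents` and the ids 101 and 102 are hypothetical, chosen only for illustration.

```python
def split_into_pieces(lines, num_pieces):
    """Split lines into num_pieces contiguous chunks of equal size.

    Mirrors the slicing arithmetic in build_files: every chunk holds
    len(lines) // num_pieces items, and the last chunk also receives
    whatever is left over at the tail.
    """
    piece_len = len(lines) // num_pieces
    pieces = [lines[piece_len * i: piece_len * (i + 1)] for i in range(num_pieces)]
    pieces[-1].extend(lines[piece_len * num_pieces:])
    return pieces


def frame_documents(docs_ids, mask_id, cls_id):
    """Flatten per-article id lists into one stream.

    Each article is preceded by mask_id and followed by cls_id,
    as build_files does with the ids of '[MASK]' and '[CLS]'.
    """
    flat = []
    for ids in docs_ids:
        flat.append(mask_id)
        flat.extend(ids)
        flat.append(cls_id)
    return flat


if __name__ == '__main__':
    print(split_into_pieces(list(range(10)), 3))
    # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
    print(frame_documents([[5, 6], [7]], mask_id=101, cls_id=102))
    # [101, 5, 6, 102, 101, 7, 102]
```

One consequence of this arithmetic: when num_pieces exceeds the number of lines, the per-piece length is 0, so every piece except the last is empty and the last one receives the whole corpus.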

Callers 1

main · Function · 0.70

Calls 2

tokenize · Method · 0.45
convert_tokens_to_ids · Method · 0.45

Tested by

no test coverage detected
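
Since no tests are detected, a minimal check could be sketched as follows. Everything here is an assumption made for illustration: `StubTokenizer` is a hypothetical whitespace tokenizer standing in for the real one (it only provides the two methods build_files calls), and `check_build_files` and the sample documents are invented. The function under test is passed in as an argument so the sketch does not depend on how train.py is imported.

```python
import json
import os
import tempfile


class StubTokenizer:
    """Hypothetical stand-in for the real tokenizer: splits on whitespace
    and assigns ids on first use. '[MASK]' is 0 and '[CLS]' is 1."""

    def __init__(self):
        self.vocab = {'[MASK]': 0, '[CLS]': 1, '[SEP]': 2}

    def tokenize(self, text):
        return text.split()

    def _id(self, token):
        return self.vocab.setdefault(token, len(self.vocab))

    def convert_tokens_to_ids(self, tokens):
        # build_files passes both a single token string and a list of tokens.
        if isinstance(tokens, str):
            return self._id(tokens)
        return [self._id(t) for t in tokens]


def check_build_files(build_files_fn):
    """Run build_files_fn on a tiny corpus and return the ids read back per piece."""
    docs = ['a b c d', 'e f\ng h', 'x', 'i j k l m']  # 'x' is too short to be kept
    with tempfile.TemporaryDirectory() as tmp:
        data_path = os.path.join(tmp, 'train.json')
        with open(data_path, 'w', encoding='utf8') as f:
            json.dump(docs, f)
        # Trailing separator: the output path may be built by plain concatenation.
        out_dir = os.path.join(tmp, 'tokenized') + os.sep
        build_files_fn(data_path, out_dir, 2, StubTokenizer(), 3)

        pieces = []
        for i in range(2):
            path = os.path.join(out_dir, 'tokenized_train_{}.txt'.format(i))
            with open(path, 'r') as f:
                pieces.append([int(tok) for tok in f.read().split()])

    # Each piece starts with [MASK] (0) and ends with [CLS] (1).
    for ids in pieces:
        assert ids[0] == 0 and ids[-1] == 1
    # Piece 0 holds two articles; piece 1 holds one because 'x' was filtered out.
    assert pieces[0].count(0) == 2
    assert pieces[1].count(0) == 1
    return pieces
```

With build_files available in scope, the check would be run as `check_build_files(build_files)`. It only verifies the file layout and the article markers, not the behavior of any real tokenizer.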