MCPcopy
hub / github.com/Morizeyao/GPT2-Chinese / main

Function main

eval.py:39–180  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

37
38
39def main():
40 parser = argparse.ArgumentParser()
41 parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='设置使用哪些显卡')
42 parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
43 help='选择模型参数')
44 parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='选择词库')
45 parser.add_argument('--raw_data_path', default='data/eval.json', type=str, required=False, help='原始语料')
46 parser.add_argument('--tokenized_data_path', default='data/tokenized_eval/', type=str, required=False,
47 help='tokenized语料存放位置')
48 parser.add_argument('--raw', action='store_true', help='是否先做tokenize')
49 parser.add_argument('--batch_size', default=8, type=int, required=False, help='batch size')
50 parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次')
51 parser.add_argument('--stride', default=768, type=int, required=False, help='取数据的窗口步长')
52 parser.add_argument('--num_pieces', default=100, type=int, required=False, help='将训练语料分成多少份')
53 parser.add_argument('--min_length', default=128, type=int, required=False, help='最短收录文章长度')
54 parser.add_argument('--pretrained_model', default='', type=str, required=False, help='模型起点路径')
55 parser.add_argument('--output_dir', default='eval_result/', type=str, required=False, help='结果输出路径')
56
57 args = parser.parse_args()
58 print('args:\n' + args.__repr__())
59
60 # if args.no_wordpiece:
61 # from tokenizations import tokenization_bert_without_wordpiece as tokenization_bert
62 # else:
63 from tokenizations import tokenization_bert
64
65 os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡
66
67 model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
68 print('config:\n' + model_config.to_json_string())
69
70 n_ctx = model_config.n_ctx
71 full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
72 full_tokenizer.max_len = n_ctx
73 device = 'cuda' if torch.cuda.is_available() else 'cpu'
74 print('using device:', device)
75
76 raw_data_path = args.raw_data_path
77 tokenized_data_path = args.tokenized_data_path
78 raw = args.raw # 选择是否从零开始构建数据集
79 batch_size = args.batch_size
80 log_step = args.log_step
81 stride = args.stride
82 num_pieces = args.num_pieces
83 min_length = args.min_length
84 output_dir = args.output_dir
85
86 if not os.path.exists(output_dir):
87 os.mkdir(output_dir)
88
89 if raw:
90 print('building files')
91 build_files(data_path=raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces,
92 full_tokenizer=full_tokenizer, min_length=min_length)
93 print('files built')
94
95 if not args.pretrained_model:
96 print('you need to specify a trained model.')

Callers 1

eval.py · File · 0.70

Calls 3

build_files · Function · 0.70
from_pretrained · Method · 0.45
convert_tokens_to_ids · Method · 0.45

Tested by

no test coverage detected