MCPcopy
hub / github.com/Morizeyao/GPT2-Chinese / main

Function main

train_single.py:38–223  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

36
37
38def main():
39 parser = argparse.ArgumentParser()
40 parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='设置使用哪些显卡')
41 parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
42 help='选择模型参数')
43 parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='选择词库')
44 parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='原始训练语料')
45 parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False,
46 help='tokenized语料存放位置')
47 parser.add_argument('--raw', action='store_true', help='是否先做tokenize')
48 parser.add_argument('--epochs', default=5, type=int, required=False, help='训练循环')
49 parser.add_argument('--batch_size', default=8, type=int, required=False, help='训练batch size')
50 parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='学习率')
51 parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='warm up步数')
52 parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss')
53 parser.add_argument('--stride', default=768, type=int, required=False, help='训练时取训练数据的窗口步长')
54 parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='梯度积累')
55 parser.add_argument('--fp16', action='store_true', help='混合精度')
56 parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False)
57 parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
58 parser.add_argument('--num_pieces', default=100, type=int, required=False, help='将训练语料分成多少份')
59 parser.add_argument('--output_dir', default='model/', type=str, required=False, help='模型输出路径')
60 parser.add_argument('--pretrained_model', default='', type=str, required=False, help='模型训练起点路径')
61 parser.add_argument('--segment', action='store_true', help='中文以词为单位')
62
63 args = parser.parse_args()
64 print('args:\n' + args.__repr__())
65
66 if args.segment:
67 from tokenizations import tokenization_bert_word_level as tokenization_bert
68 else:
69 from tokenizations import tokenization_bert
70
71 os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡
72 model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
73 print('config:\n' + model_config.to_json_string())
74
75 n_ctx = model_config.n_ctx
76 full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
77 full_tokenizer.max_len = 999999
78 device = 'cuda' if torch.cuda.is_available() else 'cpu'
79 print('using device:', device)
80
81 raw_data_path = args.raw_data_path
82 tokenized_data_path = args.tokenized_data_path
83 raw = args.raw # 选择是否从零开始构建数据集
84 epochs = args.epochs
85 batch_size = args.batch_size
86 lr = args.lr
87 warmup_steps = args.warmup_steps
88 log_step = args.log_step
89 stride = args.stride
90 gradient_accumulation = args.gradient_accumulation
91 fp16 = args.fp16 # 不支持半精度的显卡请勿打开
92 fp16_opt_level = args.fp16_opt_level
93 max_grad_norm = args.max_grad_norm
94 num_pieces = args.num_pieces
95 output_dir = args.output_dir

Callers 1

train_single.py — File · 0.70

Calls 2

build_files — Function · 0.70
from_pretrained — Method · 0.45

Tested by

no test coverage detected