()
| 36 | |
| 37 | |
| 38 | def main(): |
| 39 | parser = argparse.ArgumentParser() |
| 40 | parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='设置使用哪些显卡') |
| 41 | parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, |
| 42 | help='选择模型参数') |
| 43 | parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='选择词库') |
| 44 | parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='原始训练语料') |
| 45 | parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False, |
| 46 | help='tokenized语料存放位置') |
| 47 | parser.add_argument('--raw', action='store_true', help='是否先做tokenize') |
| 48 | parser.add_argument('--epochs', default=5, type=int, required=False, help='训练循环') |
| 49 | parser.add_argument('--batch_size', default=8, type=int, required=False, help='训练batch size') |
| 50 | parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='学习率') |
| 51 | parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='warm up步数') |
| 52 | parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss') |
| 53 | parser.add_argument('--stride', default=768, type=int, required=False, help='训练时取训练数据的窗口步长') |
| 54 | parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='梯度积累') |
| 55 | parser.add_argument('--fp16', action='store_true', help='混合精度') |
| 56 | parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False) |
| 57 | parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False) |
| 58 | parser.add_argument('--num_pieces', default=100, type=int, required=False, help='将训练语料分成多少份') |
| 59 | parser.add_argument('--output_dir', default='model/', type=str, required=False, help='模型输出路径') |
| 60 | parser.add_argument('--pretrained_model', default='', type=str, required=False, help='模型训练起点路径') |
| 61 | parser.add_argument('--segment', action='store_true', help='中文以词为单位') |
| 62 | |
| 63 | args = parser.parse_args() |
| 64 | print('args:\n' + args.__repr__()) |
| 65 | |
| 66 | if args.segment: |
| 67 | from tokenizations import tokenization_bert_word_level as tokenization_bert |
| 68 | else: |
| 69 | from tokenizations import tokenization_bert |
| 70 | |
| 71 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡 |
| 72 | model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config) |
| 73 | print('config:\n' + model_config.to_json_string()) |
| 74 | |
| 75 | n_ctx = model_config.n_ctx |
| 76 | full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path) |
| 77 | full_tokenizer.max_len = 999999 |
| 78 | device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| 79 | print('using device:', device) |
| 80 | |
| 81 | raw_data_path = args.raw_data_path |
| 82 | tokenized_data_path = args.tokenized_data_path |
| 83 | raw = args.raw # 选择是否从零开始构建数据集 |
| 84 | epochs = args.epochs |
| 85 | batch_size = args.batch_size |
| 86 | lr = args.lr |
| 87 | warmup_steps = args.warmup_steps |
| 88 | log_step = args.log_step |
| 89 | stride = args.stride |
| 90 | gradient_accumulation = args.gradient_accumulation |
| 91 | fp16 = args.fp16 # 不支持半精度的显卡请勿打开 |
| 92 | fp16_opt_level = args.fp16_opt_level |
| 93 | max_grad_norm = args.max_grad_norm |
| 94 | num_pieces = args.num_pieces |
| 95 | output_dir = args.output_dir |
no test coverage detected