()
| 37 | |
| 38 | |
| 39 | def main(): |
| 40 | parser = argparse.ArgumentParser() |
| 41 | parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='设置使用哪些显卡') |
| 42 | parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, |
| 43 | help='选择模型参数') |
| 44 | parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='选择词库') |
| 45 | parser.add_argument('--raw_data_path', default='data/eval.json', type=str, required=False, help='原始语料') |
| 46 | parser.add_argument('--tokenized_data_path', default='data/tokenized_eval/', type=str, required=False, |
| 47 | help='tokenized语料存放位置') |
| 48 | parser.add_argument('--raw', action='store_true', help='是否先做tokenize') |
| 49 | parser.add_argument('--batch_size', default=8, type=int, required=False, help='batch size') |
| 50 | parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次') |
| 51 | parser.add_argument('--stride', default=768, type=int, required=False, help='取数据的窗口步长') |
| 52 | parser.add_argument('--num_pieces', default=100, type=int, required=False, help='将训练语料分成多少份') |
| 53 | parser.add_argument('--min_length', default=128, type=int, required=False, help='最短收录文章长度') |
| 54 | parser.add_argument('--pretrained_model', default='', type=str, required=False, help='模型起点路径') |
| 55 | parser.add_argument('--output_dir', default='eval_result/', type=str, required=False, help='结果输出路径') |
| 56 | |
| 57 | args = parser.parse_args() |
| 58 | print('args:\n' + args.__repr__()) |
| 59 | |
| 60 | # if args.no_wordpiece: |
| 61 | # from tokenizations import tokenization_bert_without_wordpiece as tokenization_bert |
| 62 | # else: |
| 63 | from tokenizations import tokenization_bert |
| 64 | |
| 65 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡 |
| 66 | |
| 67 | model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config) |
| 68 | print('config:\n' + model_config.to_json_string()) |
| 69 | |
| 70 | n_ctx = model_config.n_ctx |
| 71 | full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path) |
| 72 | full_tokenizer.max_len = n_ctx |
| 73 | device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| 74 | print('using device:', device) |
| 75 | |
| 76 | raw_data_path = args.raw_data_path |
| 77 | tokenized_data_path = args.tokenized_data_path |
| 78 | raw = args.raw # 选择是否从零开始构建数据集 |
| 79 | batch_size = args.batch_size |
| 80 | log_step = args.log_step |
| 81 | stride = args.stride |
| 82 | num_pieces = args.num_pieces |
| 83 | min_length = args.min_length |
| 84 | output_dir = args.output_dir |
| 85 | |
| 86 | if not os.path.exists(output_dir): |
| 87 | os.mkdir(output_dir) |
| 88 | |
| 89 | if raw: |
| 90 | print('building files') |
| 91 | build_files(data_path=raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces, |
| 92 | full_tokenizer=full_tokenizer, min_length=min_length) |
| 93 | print('files built') |
| 94 | |
| 95 | if not args.pretrained_model: |
| 96 | print('you need to specify a trained model.') |
no test coverage detected