()
| 122 | |
| 123 | |
| 124 | def main(): |
| 125 | parser = argparse.ArgumentParser() |
| 126 | parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='生成设备') |
| 127 | parser.add_argument('--length', default=-1, type=int, required=False, help='生成长度') |
| 128 | parser.add_argument('--batch_size', default=1, type=int, required=False, help='生成的batch size') |
| 129 | parser.add_argument('--nsamples', default=10, type=int, required=False, help='生成几个样本') |
| 130 | parser.add_argument('--temperature', default=1, type=float, required=False, help='生成温度') |
| 131 | parser.add_argument('--topk', default=8, type=int, required=False, help='最高几选一') |
| 132 | parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率') |
| 133 | parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, |
| 134 | help='模型参数') |
| 135 | parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='词表路径') |
| 136 | parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='模型路径') |
| 137 | parser.add_argument('--prefix', default='萧炎', type=str, required=False, help='生成文章的开头') |
| 138 | parser.add_argument('--no_wordpiece', action='store_true', help='不做word piece切词') |
| 139 | parser.add_argument('--segment', action='store_true', help='中文以词为单位') |
| 140 | parser.add_argument('--fast_pattern', action='store_true', help='采用更加快的方式生成文本') |
| 141 | parser.add_argument('--save_samples', action='store_true', help='保存产生的样本') |
| 142 | parser.add_argument('--save_samples_path', default='.', type=str, required=False, help="保存样本的路径") |
| 143 | parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False) |
| 144 | |
| 145 | args = parser.parse_args() |
| 146 | print('args:\n' + args.__repr__()) |
| 147 | |
| 148 | if args.segment: |
| 149 | from tokenizations import tokenization_bert_word_level as tokenization_bert |
| 150 | else: |
| 151 | from tokenizations import tokenization_bert |
| 152 | |
| 153 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡 |
| 154 | length = args.length |
| 155 | batch_size = args.batch_size |
| 156 | nsamples = args.nsamples |
| 157 | temperature = args.temperature |
| 158 | topk = args.topk |
| 159 | topp = args.topp |
| 160 | repetition_penalty = args.repetition_penalty |
| 161 | |
| 162 | device = "cuda" if torch.cuda.is_available() else "cpu" |
| 163 | |
| 164 | tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path) |
| 165 | model = GPT2LMHeadModel.from_pretrained(args.model_path) |
| 166 | model.to(device) |
| 167 | model.eval() |
| 168 | |
| 169 | n_ctx = model.config.n_ctx |
| 170 | |
| 171 | if length == -1: |
| 172 | length = model.config.n_ctx |
| 173 | if args.save_samples: |
| 174 | if not os.path.exists(args.save_samples_path): |
| 175 | os.makedirs(args.save_samples_path) |
| 176 | samples_file = open(args.save_samples_path + '/samples.txt', 'w', encoding='utf8') |
| 177 | while True: |
| 178 | raw_text = args.prefix |
| 179 | context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text)) |
| 180 | generated = 0 |
| 181 | for _ in range(nsamples // batch_size): |
no test coverage detected