MCPcopy
hub / github.com/Morizeyao/GPT2-Chinese / main

Function main

generate.py:124–218  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

122
123
124def main():
125 parser = argparse.ArgumentParser()
126 parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='生成设备')
127 parser.add_argument('--length', default=-1, type=int, required=False, help='生成长度')
128 parser.add_argument('--batch_size', default=1, type=int, required=False, help='生成的batch size')
129 parser.add_argument('--nsamples', default=10, type=int, required=False, help='生成几个样本')
130 parser.add_argument('--temperature', default=1, type=float, required=False, help='生成温度')
131 parser.add_argument('--topk', default=8, type=int, required=False, help='最高几选一')
132 parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率')
133 parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
134 help='模型参数')
135 parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='词表路径')
136 parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='模型路径')
137 parser.add_argument('--prefix', default='萧炎', type=str, required=False, help='生成文章的开头')
138 parser.add_argument('--no_wordpiece', action='store_true', help='不做word piece切词')
139 parser.add_argument('--segment', action='store_true', help='中文以词为单位')
140 parser.add_argument('--fast_pattern', action='store_true', help='采用更加快的方式生成文本')
141 parser.add_argument('--save_samples', action='store_true', help='保存产生的样本')
142 parser.add_argument('--save_samples_path', default='.', type=str, required=False, help="保存样本的路径")
143 parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False)
144
145 args = parser.parse_args()
146 print('args:\n' + args.__repr__())
147
148 if args.segment:
149 from tokenizations import tokenization_bert_word_level as tokenization_bert
150 else:
151 from tokenizations import tokenization_bert
152
153 os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡
154 length = args.length
155 batch_size = args.batch_size
156 nsamples = args.nsamples
157 temperature = args.temperature
158 topk = args.topk
159 topp = args.topp
160 repetition_penalty = args.repetition_penalty
161
162 device = "cuda" if torch.cuda.is_available() else "cpu"
163
164 tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
165 model = GPT2LMHeadModel.from_pretrained(args.model_path)
166 model.to(device)
167 model.eval()
168
169 n_ctx = model.config.n_ctx
170
171 if length == -1:
172 length = model.config.n_ctx
173 if args.save_samples:
174 if not os.path.exists(args.save_samples_path):
175 os.makedirs(args.save_samples_path)
176 samples_file = open(args.save_samples_path + '/samples.txt', 'w', encoding='utf8')
177 while True:
178 raw_text = args.prefix
179 context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
180 generated = 0
181 for _ in range(nsamples // batch_size):

Callers 1

generate.py — File · 0.70

Calls 5

generate — Function · 0.85
is_word — Function · 0.70
from_pretrained — Method · 0.45
convert_tokens_to_ids — Method · 0.45
tokenize — Method · 0.45

Tested by

no test coverage detected