(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repitition_penalty=1.0, device='cpu',
is_fast_pattern=False)
| 112 | |
def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repitition_penalty=1.0, device='cpu',
             is_fast_pattern=False):
    """Dispatch text generation to the fast or the standard sampling routine.

    The sampling mode is chosen by ``is_fast_pattern``. Per the original
    comment, it is driven by the ``--fast_pattern`` command-line argument.

    Args:
        n_ctx: Forwarded (positionally) to ``sample_sequence`` only.
        model: Forwarded positionally to whichever sampler runs.
        context: Forwarded positionally to whichever sampler runs.
        length: Forwarded positionally to whichever sampler runs.
        tokenizer: Forwarded to ``sample_sequence`` only.
        temperature: Forwarded to both samplers.
        top_k: Forwarded to both samplers.
        top_p: Forwarded to both samplers.
        repitition_penalty: Forwarded to ``sample_sequence`` only. The
            spelling is kept because it is part of the keyword interface.
        device: Forwarded to both samplers.
        is_fast_pattern: When truthy, ``fast_sample_sequence`` is used;
            otherwise ``sample_sequence`` is used.

    Returns:
        Whatever the selected sampler returns, unchanged.
    """
    # Keyword arguments accepted by both sampling routines.
    common = dict(temperature=temperature, top_k=top_k, top_p=top_p, device=device)

    if is_fast_pattern:
        # NOTE: the fast path does not receive n_ctx, tokenizer or
        # repitition_penalty; those settings have no effect in this mode.
        return fast_sample_sequence(model, context, length, **common)

    return sample_sequence(model, context, length, n_ctx,
                           tokenizer=tokenizer, repitition_penalty=repitition_penalty, **common)
| 122 | |
| 123 | |
| 124 | def main(): |
no test coverage detected