(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
device='cpu')
| 69 | |
| 70 | |
def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0,
                    repitition_penalty=1.0, device='cpu'):
    """Sample ``length`` new tokens from ``model``, one token at a time.

    Each step feeds the last ``n_ctx - 1`` tokens of the running sequence to
    the model, applies the repetition penalty and temperature to the logits of
    the final position, masks out the ``[UNK]`` token, filters with
    ``top_k_top_p_filtering`` and draws one token with ``torch.multinomial``.

    Args:
        model: Callable invoked as ``model(input_ids=...)``. Its first output
            is indexed as ``[0, -1, :]``, i.e. it is assumed to hold logits
            laid out as (batch, sequence, vocab) -- confirm against the model.
        context: Sequence of token ids used as the prompt; converted to a
            ``torch.long`` tensor on ``device``.
        length: Number of tokens to generate.
        n_ctx: Context window size; only the last ``n_ctx - 1`` tokens are
            passed to the model at each step.
        tokenizer: Object exposing ``convert_tokens_to_ids``; used to look up
            the id of ``'[UNK]'`` so it is never sampled.
        temperature: Divisor applied to the logits before filtering.
        top_k: Forwarded to ``top_k_top_p_filtering``.
        top_p: Forwarded to ``top_k_top_p_filtering``.
        repitition_penalty: Penalty applied to the logits of every token id
            already present in the sequence (prompt included). The parameter
            name keeps its original spelling for caller compatibility.
        device: Device on which the prompt tensor is created.

    Returns:
        list: The prompt token ids followed by the generated token ids.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0)  # add a batch dimension: (1, seq_len)

    # The id of '[UNK]' does not change between steps, so look it up once.
    unk_id = tokenizer.convert_tokens_to_ids('[UNK]')

    with torch.no_grad():
        for _ in trange(length):
            # Keep only the most recent n_ctx - 1 tokens as model input.
            inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            outputs = model(**inputs)
            next_token_logits = outputs[0][0, -1, :]

            # Penalise every token id that already occurs in the sequence.
            # Dividing a negative logit by a penalty > 1 would move it towards
            # zero and make the token MORE likely, so negative logits are
            # multiplied by the penalty instead.
            seen_ids = torch.unique(generated)
            seen_logits = next_token_logits[seen_ids]
            next_token_logits[seen_ids] = torch.where(
                seen_logits < 0,
                seen_logits * repitition_penalty,
                seen_logits / repitition_penalty,
            )

            next_token_logits = next_token_logits / temperature
            # Never sample the unknown-token placeholder.
            next_token_logits[unk_id] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)

    # Drop the batch dimension and return plain Python ints.
    return generated.tolist()[0]
| 90 | |
| 91 | |
| 92 | def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'): |
no test coverage detected