github.com/Morizeyao/GPT2-Chinese

Function sample_sequence

generate.py, lines 71–89
(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0, device='cpu')


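import torch
import torch.nn.functional as F
from tqdm import trange

# top_k_top_p_filtering is defined elsewhere in generate.py.


def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
                    device='cpu'):
    # Prompt token ids as a (1, seq_len) batch.
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in trange(length):
            # Feed only the last n_ctx - 1 tokens so the input fits the model's context window.
            inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
            outputs = model(
                **inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            # Logits for the position after the last input token.
            next_token_logits = outputs[0][0, -1, :]
            # Repetition penalty over the distinct token ids generated so far.
            # Dividing a negative logit by a penalty > 1 moves it toward zero, i.e. makes that token more likely.
            for token_id in set(generated[0].tolist()):
                next_token_logits[token_id] /= repitition_penalty
            next_token_logits = next_token_logits / temperature
            # Never sample the [UNK] token.
            next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    # Prompt ids followed by the sampled ids, as a flat Python list.
    return generated.tolist()[0]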
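To make the per-step logic easier to follow, here is a dependency-free Python sketch that mirrors the loop in sample_sequence: truncate the input to the last n_ctx - 1 ids, apply the repetition penalty, divide by the temperature, mask [UNK], take the softmax, and sample. It is not the repository's code. toy_model, sample_loop and apply_repetition_penalty are hypothetical names, random.choices stands in for torch.multinomial, top-k/top-p filtering is omitted, and the penalty is a swapped-in sign-aware variant (divide positive logits, multiply negative ones) rather than the plain division sample_sequence uses.

```python
import math
import random


def apply_repetition_penalty(logits, seen_ids, penalty):
    """Penalise every token id already generated (sign-aware variant, an assumption).

    Positive logits are divided by `penalty` and negative ones multiplied,
    so penalty > 1 always lowers a repeated token's probability.
    """
    out = list(logits)
    for i in set(seen_ids):
        out[i] = out[i] / penalty if out[i] > 0 else out[i] * penalty
    return out


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def sample_loop(toy_model, context, length, n_ctx, unk_id,
                temperature=1.0, penalty=1.0, seed=0):
    """Pure-Python mirror of the sampling loop (top-k/top-p filtering omitted).

    `toy_model(ids)` is a hypothetical stand-in returning one logit per
    vocabulary entry for the token that follows `ids`.
    """
    rng = random.Random(seed)
    generated = list(context)
    for _ in range(length):
        window = generated[-(n_ctx - 1):]          # same slice as sample_sequence
        logits = toy_model(window)
        logits = apply_repetition_penalty(logits, generated, penalty)
        logits = [x / temperature for x in logits]
        logits[unk_id] = -math.inf                 # never emit [UNK]
        probs = softmax(logits)
        next_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        generated.append(next_id)
    return generated  # prompt followed by `length` new ids
```

With the sign-aware rule, raising penalty above 1 lowers the probability of every id already in generated, whatever the sign of its logit; with plain division, a repeated token whose logit is negative would instead become more likely.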

Callers (1)

generate (function) · 0.70

Calls (2)

top_k_top_p_filtering (function) · 0.70
convert_tokens_to_ids (method) · 0.45
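The body of top_k_top_p_filtering is not shown on this page. The following is a hypothetical, dependency-free sketch of the common top-k plus nucleus (top-p) filtering technique, not generate.py's implementation; the name top_k_top_p_filter and the filter_value parameter are illustrative. It assumes, as the defaults top_k=30 and top_p=0.0 suggest, that a value of 0 disables the corresponding filter.

```python
import math


def top_k_top_p_filter(logits, top_k=0, top_p=0.0, filter_value=-math.inf):
    """Hypothetical sketch of top-k plus nucleus (top-p) filtering.

    Returns a copy of `logits` in which filtered entries are set to
    `filter_value`. A value of 0 disables the corresponding filter (assumption).
    """
    out = list(logits)
    # Token ids ordered from highest to lowest logit.
    order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)

    if top_k > 0:
        # Keep only the top_k highest-scoring ids.
        for i in order[top_k:]:
            out[i] = filter_value

    if top_p > 0.0:
        # Softmax over the (possibly top-k filtered) logits.
        m = max(out)
        exps = [math.exp(x - m) for x in out]
        total = sum(exps)
        cumulative = 0.0
        for i in order:
            # Drop an id once the ids ranked above it already cover more than top_p
            # of the probability mass; the highest-ranked id is therefore always kept.
            if cumulative > top_p:
                out[i] = filter_value
            cumulative += exps[i] / total
    return out
```

Keeping the id that crosses the top_p threshold, rather than stopping just before it, guarantees that at least one token survives even when top_p is smaller than the top token's probability.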

Tested by

no test coverage detected