(model, input_str, tokenizer, max_gen_len, temperature, top_p, top_k,
eos_token_id=None, pad_token_id=None)
| 74 | return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:] |
| 75 | |
| 76 | def sample(model, input_str, tokenizer, max_gen_len, temperature, top_p, top_k, |
| 77 | eos_token_id=None, pad_token_id=None): |
| 78 | model.eval() |
| 79 | if eos_token_id is None: |
| 80 | eos_token_id = tokenizer.eos_token_id |
| 81 | if pad_token_id is None and eos_token_id is not None: |
| 82 | pad_token_id = eos_token_id |
| 83 | eos_token_id_tensor = jt.Var(eos_token_id) |
| 84 | |
| 85 | tokenized = tokenizer(input_str, return_tensors='np') |
| 86 | sentence_ids = jt.Var(tokenized['input_ids']) |
| 87 | attention_mask = jt.Var(tokenized['attention_mask']) |
| 88 | unfinished_sequences = sentence_ids.new(sentence_ids.shape[0]).fill_(1) |
| 89 | past_key_values = None |
| 90 | |
| 91 | while True: |
| 92 | |
| 93 | # set input |
| 94 | if past_key_values: |
| 95 | input_ids = sentence_ids[:, -1].unsqueeze(-1) |
| 96 | else: |
| 97 | input_ids = sentence_ids |
| 98 | outputs = model(input_ids, past_key_values=past_key_values, |
| 99 | attention_mask=attention_mask) |
| 100 | |
| 101 | next_token_logits = outputs['logits'][:, -1, :].float() |
| 102 | |
| 103 | # sample |
| 104 | # temperature |
| 105 | scores = next_token_logits / temperature |
| 106 | # top_k |
| 107 | scores = sample_top_k(scores, top_k) |
| 108 | # top_p |
| 109 | scores = sample_top_p(scores, top_p) |
| 110 | |
| 111 | probs = jt.nn.softmax(scores, dim=-1) |
| 112 | next_tokens = jt.multinomial(probs, num_samples=1).squeeze(1) |
| 113 | # concat sentence |
| 114 | next_tokens = next_tokens * unfinished_sequences + \ |
| 115 | pad_token_id * (1 - unfinished_sequences) |
| 116 | |
| 117 | # update generated ids, model inputs, and length for next step |
| 118 | sentence_ids = jt.cat([sentence_ids, next_tokens[:, None]], dim=-1) |
| 119 | past_key_values = outputs['past_key_values'] |
| 120 | attention_mask = jt.cat( |
| 121 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) |
| 122 | |
| 123 | # if eos_token was found in one sentence, set sentence to finished |
| 124 | next_tokens.repeat(eos_token_id_tensor.shape[0], 1) |
| 125 | unfinished_sequences = unfinished_sequences.mul( |
| 126 | next_tokens.repeat(eos_token_id_tensor.shape[0], 1) \ |
| 127 | .not_equal(eos_token_id_tensor.unsqueeze(1)) \ |
| 128 | .prod(dim=0) |
| 129 | ) |
| 130 | |
| 131 | jt.sync_all() |
| 132 | |
| 133 | if unfinished_sequences.max() == 0 or sentence_ids.shape[-1] >= max_gen_len: |
no test coverage detected