MCPcopy Index your code
hub / github.com/OpenMOSS/MOSS / sample

Function sample

models_jittor/generation.py:76–136  ·  view source on GitHub ↗
(model, input_str, tokenizer, max_gen_len, temperature, top_p, top_k,
           eos_token_id=None, pad_token_id=None)

Source from the content-addressed store, hash-verified

74 return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:]
75
76def sample(model, input_str, tokenizer, max_gen_len, temperature, top_p, top_k,
77 eos_token_id=None, pad_token_id=None):
78 model.eval()
79 if eos_token_id is None:
80 eos_token_id = tokenizer.eos_token_id
81 if pad_token_id is None and eos_token_id is not None:
82 pad_token_id = eos_token_id
83 eos_token_id_tensor = jt.Var(eos_token_id)
84
85 tokenized = tokenizer(input_str, return_tensors='np')
86 sentence_ids = jt.Var(tokenized['input_ids'])
87 attention_mask = jt.Var(tokenized['attention_mask'])
88 unfinished_sequences = sentence_ids.new(sentence_ids.shape[0]).fill_(1)
89 past_key_values = None
90
91 while True:
92
93 # set input
94 if past_key_values:
95 input_ids = sentence_ids[:, -1].unsqueeze(-1)
96 else:
97 input_ids = sentence_ids
98 outputs = model(input_ids, past_key_values=past_key_values,
99 attention_mask=attention_mask)
100
101 next_token_logits = outputs['logits'][:, -1, :].float()
102
103 # sample
104 # temperature
105 scores = next_token_logits / temperature
106 # top_k
107 scores = sample_top_k(scores, top_k)
108 # top_p
109 scores = sample_top_p(scores, top_p)
110
111 probs = jt.nn.softmax(scores, dim=-1)
112 next_tokens = jt.multinomial(probs, num_samples=1).squeeze(1)
113 # concat sentence
114 next_tokens = next_tokens * unfinished_sequences + \
115 pad_token_id * (1 - unfinished_sequences)
116
117 # update generated ids, model inputs, and length for next step
118 sentence_ids = jt.cat([sentence_ids, next_tokens[:, None]], dim=-1)
119 past_key_values = outputs['past_key_values']
120 attention_mask = jt.cat(
121 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
122
123 # if eos_token was found in one sentence, set sentence to finished
124 next_tokens.repeat(eos_token_id_tensor.shape[0], 1)
125 unfinished_sequences = unfinished_sequences.mul(
126 next_tokens.repeat(eos_token_id_tensor.shape[0], 1) \
127 .not_equal(eos_token_id_tensor.unsqueeze(1)) \
128 .prod(dim=0)
129 )
130
131 jt.sync_all()
132
133 if unfinished_sequences.max() == 0 or sentence_ids.shape[-1] >= max_gen_len:

Callers 1

generate · Function · 0.85

Calls 2

sample_top_k · Function · 0.85
sample_top_p · Function · 0.85

Tested by

no test coverage detected