MCPcopy
hub / github.com/jingyaogong/minimind / MiniMindForCausalLM

Class MiniMindForCausalLM

model/model_minimind.py:234–288  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

232 return hidden_states, presents, aux_loss
233
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language-model head on top of ``MiniMindModel``.

    Wraps the backbone with a bias-free linear projection from
    ``hidden_size`` to ``vocab_size`` and provides a self-contained
    sampling loop in :meth:`generate`.
    """

    config_class = MiniMindConfig
    # Declares that the LM head shares its weight with the input embedding.
    # NOTE(review): presumably consumed by the PreTrainedModel base class when
    # tying/saving weights - confirm against the installed transformers version.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: MiniMindConfig = None):
        """Build the backbone and the LM head.

        Args:
            config: Model configuration. A default ``MiniMindConfig()`` is
                created when this is ``None`` (or otherwise falsy).
        """
        self.config = config or MiniMindConfig()
        super().__init__(self.config)
        self.model = MiniMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        if self.config.tie_word_embeddings:
            # Share a single parameter between the input embedding and the output head.
            self.model.embed_tokens.weight = self.lm_head.weight
        self.post_init()

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
        """Run the backbone, project to vocabulary logits and optionally compute the LM loss.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Optional mask passed through to the backbone.
            past_key_values: Optional cache passed through to the backbone.
            use_cache: Passed through to the backbone.
            logits_to_keep: If an ``int``, only the last ``logits_to_keep``
                sequence positions are projected; ``0`` keeps every position
                because ``slice(-0, None)`` equals ``slice(0, None)``. Any
                other value is used directly as the index on the sequence
                dimension.
            labels: Optional target ids. When given, a next-token
                cross-entropy loss is computed with ``-100`` as the ignored
                label value.
            **kwargs: Forwarded unchanged to the backbone.

        Returns:
            ``MoeCausalLMOutputWithPast`` carrying ``loss`` (``None`` without
            labels), ``aux_loss``, ``logits``, ``past_key_values`` and the
            backbone's ``hidden_states``.
        """
        hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            # Shift by one so that position t is scored against the token at t + 1.
            x, y = logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()
            loss = F.cross_entropy(x.view(-1, x.size(-1)), y.view(-1), ignore_index=-100)
        return MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)

    # https://github.com/jingyaogong/minimind/discussions/611
    @torch.inference_mode()
    def generate(self, inputs=None, attention_mask=None, max_new_tokens=8192, temperature=0.85, top_p=0.85, top_k=50, eos_token_id=2, streamer=None, use_cache=True, num_return_sequences=1, do_sample=True, repetition_penalty=1.0, **kwargs):
        """Autoregressively extend the prompt with up to ``max_new_tokens`` tokens.

        Args:
            inputs: Prompt token ids; an ``input_ids`` keyword argument takes
                precedence when both are supplied.
            attention_mask: Optional mask for the prompt. It is extended by
                one column of ones after every decoding step.
            max_new_tokens: Upper bound on the number of decoding steps.
            temperature: Divisor applied to the last-position logits.
            top_p: Nucleus threshold; filtering is applied when ``< 1.0``.
            top_k: Keep only the ``top_k`` highest logits when ``> 0``. Values
                larger than the vocabulary are clamped to the vocabulary size.
            eos_token_id: Token id that marks a row as finished; ``None``
                disables end-of-sequence handling.
            streamer: Optional object with ``put``/``end`` methods that
                receives the prompt and then every new token (moved to CPU).
            use_cache: Whether to reuse the key/value cache between steps.
            num_return_sequences: Number of times the prompt (and mask) is
                repeated along the first dimension.
            do_sample: Sample from the filtered distribution when true,
                otherwise take the arg-max.
            repetition_penalty: Penalty applied to the logits of tokens that
                already occur in a row; ``1.0`` disables it.
            **kwargs: ``input_ids`` and ``past_key_values`` are consumed here;
                everything else (including ``return_kv``) is forwarded to
                :meth:`forward` on every step.

        Returns:
            The prompt concatenated with the generated tokens, or - when
            ``return_kv`` is truthy in ``kwargs`` - a dict with keys
            ``'generated_ids'`` and ``'past_kv'``.

        Raises:
            ValueError: If neither ``inputs`` nor ``input_ids`` is provided.
        """
        prompt = kwargs.pop("input_ids", inputs)
        if prompt is None:
            # Fail with a clear message instead of "'NoneType' has no attribute 'repeat'".
            raise ValueError("generate() requires `inputs` or `input_ids`.")
        input_ids = prompt.repeat(num_return_sequences, 1)
        attention_mask = attention_mask.repeat(num_return_sequences, 1) if attention_mask is not None else None
        past_key_values = kwargs.pop("past_key_values", None)
        # One flag per row; set once that row has emitted eos_token_id.
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        if streamer:
            streamer.put(input_ids.cpu())

        for _ in range(max_new_tokens):
            # Feed only the tokens that are not yet covered by the cache.
            # NOTE(review): reads the cached length from dim 1 of the first
            # tensor of the first cache entry - confirm against the backbone's
            # cache layout.
            past_len = past_key_values[0][0].shape[1] if past_key_values else 0
            outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
            if attention_mask is not None:
                # Make room in the mask for the token about to be appended.
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1)
            logits = outputs.logits[:, -1, :] / temperature

            if repetition_penalty != 1.0:
                for i in range(input_ids.shape[0]):
                    # Shrink positive scores and push non-positive ones further
                    # down for every token already present in this row.
                    seen = torch.unique(input_ids[i])
                    score = logits[i, seen]
                    logits[i, seen] = torch.where(score > 0, score / repetition_penalty, score * repetition_penalty)

            if top_k > 0:
                # torch.topk raises when k exceeds the size of the last
                # dimension, so clamp k to the vocabulary size.
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k)[0][..., -1, None]
                logits[logits < kth_best] = -float('inf')

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
                # Shift right by one so the token that crosses the threshold is
                # kept, and never drop the single most likely token.
                shifted = mask[..., :-1].clone()
                mask[..., 1:] = shifted
                mask[..., 0] = 0
                # Map the removal mask from sorted order back to vocabulary order.
                logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')

            if do_sample:
                next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            if eos_token_id is not None:
                # Rows that already finished keep emitting eos_token_id.
                next_token = torch.where(finished.unsqueeze(-1), next_token.new_full((next_token.shape[0], 1), eos_token_id), next_token)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            past_key_values = outputs.past_key_values if use_cache else None
            if streamer:
                streamer.put(next_token.cpu())
            if eos_token_id is not None:
                finished |= next_token.squeeze(-1).eq(eos_token_id)
                if finished.all():
                    break

        if streamer:
            streamer.end()
        if kwargs.get("return_kv"):
            return {'generated_ids': input_ids, 'past_kv': past_key_values}
        return input_ids

Callers 6

init_modelFunction · 0.90
convert_merge_base_loraFunction · 0.90
init_modelFunction · 0.90
init_modelFunction · 0.90
init_modelFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected