| 232 | return hidden_states, presents, aux_loss |
| 233 | |
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language-model head on top of ``MiniMindModel``.

    Adds a bias-free linear projection from ``hidden_size`` to ``vocab_size``
    (optionally sharing its weight with the input embedding) and a
    self-contained autoregressive sampling loop in :meth:`generate`.
    """

    config_class = MiniMindConfig
    # Maps the output projection weight to the embedding weight it is tied to.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: MiniMindConfig = None):
        """Build the backbone and LM head from ``config`` (default ``MiniMindConfig()``)."""
        self.config = config or MiniMindConfig()
        super().__init__(self.config)
        self.model = MiniMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        if self.config.tie_word_embeddings:
            # One shared Parameter for the input embedding and the output projection.
            self.model.embed_tokens.weight = self.lm_head.weight
        self.post_init()

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
        """Run the backbone and project its hidden states to vocabulary logits.

        Args:
            input_ids: token ids, passed straight to ``self.model``.
            attention_mask: optional mask, forwarded to ``self.model``.
            past_key_values: optional KV cache, forwarded to ``self.model``.
            use_cache: forwarded to ``self.model``.
            logits_to_keep: an int ``k`` keeps only the last ``k`` sequence
                positions (``0`` keeps all of them); any non-int value is used
                directly as the index along the sequence axis.
            labels: optional targets aligned with ``input_ids``. They are
                shifted left by one (position ``t`` predicts token ``t + 1``)
                and entries equal to ``-100`` are ignored by the loss.
            **kwargs: forwarded to ``self.model``.

        Returns:
            ``MoeCausalLMOutputWithPast`` carrying ``loss`` (``None`` when no
            labels are given), the backbone's ``aux_loss``, ``logits``, the
            cache returned by the backbone, and the backbone's hidden states.
        """
        hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
        # slice(-0, None) == slice(0, None), so logits_to_keep=0 keeps every position.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            # Next-token objective: drop the last prediction and the first label.
            x, y = logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()
            loss = F.cross_entropy(x.view(-1, x.size(-1)), y.view(-1), ignore_index=-100)
        return MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)

    # https://github.com/jingyaogong/minimind/discussions/611
    @torch.inference_mode()
    def generate(self, inputs=None, attention_mask=None, max_new_tokens=8192, temperature=0.85, top_p=0.85, top_k=50, eos_token_id=2, streamer=None, use_cache=True, num_return_sequences=1, do_sample=True, repetition_penalty=1.0, **kwargs):
        """Autoregressively extend a batch of prompts.

        Args:
            inputs: 2-D tensor of prompt ids ``(batch, seq)``; may instead be
                passed as ``input_ids=`` in ``kwargs`` (which takes priority).
            attention_mask: optional ``(batch, seq)`` mask; one column of ones
                is appended after every step.
            max_new_tokens: upper bound on the number of decoding steps.
            temperature: softmax temperature used when sampling. Values
                ``<= 0`` select greedy decoding.
            top_p: nucleus-sampling threshold; ``>= 1.0`` disables it.
            top_k: keep only the ``top_k`` highest logits (clamped to the
                vocabulary size); ``<= 0`` disables it.
            eos_token_id: rows that emit this id are marked finished and keep
                emitting it; decoding stops once every row is finished.
                ``None`` disables early stopping.
            streamer: optional object with ``put(tensor)`` / ``end()``; it
                receives the prompt first and then each new token column.
            use_cache: reuse the KV cache between steps instead of
                re-encoding the full sequence every step.
            num_return_sequences: the whole batch is tiled this many times
                along dim 0 (``repeat``, not interleaved) before decoding.
            do_sample: sample from the filtered distribution; when ``False``
                take the argmax.
            repetition_penalty: divide positive / multiply negative logits of
                every token already present in a row; ``1.0`` disables it.
            **kwargs: ``input_ids`` / ``past_key_values`` are consumed here;
                everything else (including ``return_kv``) is forwarded to
                :meth:`forward`. ``logits_to_keep`` defaults to ``1`` because
                only the last position's logits are used.

        Returns:
            The ``(batch * num_return_sequences, seq + steps)`` id tensor, or,
            when ``kwargs["return_kv"]`` is truthy, a dict with keys
            ``'generated_ids'`` and ``'past_kv'``.

        Raises:
            ValueError: if neither ``inputs`` nor ``input_ids`` is given.
        """
        input_ids = kwargs.pop("input_ids", inputs)
        if input_ids is None:
            raise ValueError("generate() requires `inputs` or `input_ids`")
        input_ids = input_ids.repeat(num_return_sequences, 1)
        attention_mask = attention_mask.repeat(num_return_sequences, 1) if attention_mask is not None else None
        past_key_values = kwargs.pop("past_key_values", None)
        # Only logits[:, -1] are consumed below; avoid projecting the whole prompt.
        kwargs.setdefault("logits_to_keep", 1)
        # temperature <= 0 would turn the logits into inf/NaN; treat it as greedy.
        greedy = not do_sample or temperature <= 0
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        if streamer:
            streamer.put(input_ids.cpu())
        for _ in range(max_new_tokens):
            # NOTE(review): assumes cached keys are laid out (batch, seq, ...) so
            # dim 1 is the cached length -- confirm against MiniMindModel.
            past_len = past_key_values[0][0].shape[1] if past_key_values else 0
            outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1)
            # Private float32 copy: precise softmax/cumsum and safe in-place edits.
            logits = outputs.logits[:, -1, :].to(torch.float32, copy=True)
            if repetition_penalty != 1.0:
                # Vectorised over the batch; duplicate ids scatter identical values.
                seen = torch.gather(logits, 1, input_ids)
                seen = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
                logits.scatter_(1, input_ids, seen)
            if greedy:
                # Temperature, top-k and top-p never change the argmax, so skip them.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k > 0:
                    kth_best = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)[0][..., -1, None]
                    logits[logits < kth_best] = -float('inf')
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    remove = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
                    # Shift right so the token crossing the threshold survives;
                    # the single most likely token is always kept.
                    remove[..., 1:] = remove[..., :-1].clone()
                    remove[..., 0] = False
                    logits[remove.scatter(1, sorted_indices, remove)] = -float('inf')
                next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
            if eos_token_id is not None:
                # Finished rows keep emitting EOS so every row stays the same length.
                next_token = torch.where(finished.unsqueeze(-1), next_token.new_full((next_token.shape[0], 1), eos_token_id), next_token)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            past_key_values = outputs.past_key_values if use_cache else None
            if streamer:
                streamer.put(next_token.cpu())
            if eos_token_id is not None:
                finished |= next_token.squeeze(-1).eq(eos_token_id)
                if finished.all():
                    break
        if streamer:
            streamer.end()
        if kwargs.get("return_kv"):
            return {'generated_ids': input_ids, 'past_kv': past_key_values}
        return input_ids
no outgoing calls
no test coverage detected