Lightweight wrapper to give HuggingFace models a nanochat-compatible interface.
| 43 | # HuggingFace loading utilities |
| 44 | |
class ModelWrapper:
    """Lightweight wrapper to give HuggingFace models a nanochat-compatible interface.

    Calling the wrapper returns logits when only ``input_ids`` are given, and a
    cross-entropy loss when ``targets`` are also given.

    Args:
        model: a model whose call ``model(input_ids)`` returns an object with a
            ``.logits`` attribute (e.g. a HuggingFace causal LM).
        max_seq_len: optional maximum sequence length. It is only stored on the
            wrapper for callers to read; it is not enforced here.
    """

    def __init__(self, model, max_seq_len=None):
        self.model = model
        self.max_seq_len = max_seq_len

    def __call__(self, input_ids, targets=None, loss_reduction='mean'):
        """Run the wrapped model; return logits, or the loss if ``targets`` is given.

        Args:
            input_ids: token ids passed straight to the wrapped model.
            targets: optional target ids, one per logit position. Positions set
                to -1 are ignored. They need not be contiguous in memory.
            loss_reduction: forwarded as ``reduction`` to
                ``torch.nn.functional.cross_entropy`` ('mean', 'sum' or 'none').

        Returns:
            The model's raw ``logits`` tensor when ``targets`` is None.
            Otherwise, the loss: a scalar for 'mean'/'sum', or a flat
            per-position tensor for 'none'.
        """
        logits = self.model(input_ids).logits
        if targets is None:
            return logits
        # Upcast before the softmax: low-precision (fp16/bf16) logits lose
        # accuracy in the log-sum-exp and would yield a low-precision loss.
        # This is a no-op when the logits are already fp32.
        logits = logits.float()
        loss = torch.nn.functional.cross_entropy(
            # reshape (not view) so non-contiguous logits/targets, e.g. from
            # transposes or strided slices, are accepted instead of raising.
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=-1,
            reduction=loss_reduction,
        )
        return loss

    def get_device(self):
        """Return the device of the wrapped model's first parameter."""
        return next(self.model.parameters()).device
| 66 | |
| 67 | def load_hf_model(hf_path: str, device): |