hub / github.com/karpathy/nanochat / ModelWrapper

Class ModelWrapper

scripts/base_eval.py:45–64

Lightweight wrapper to give HuggingFace models a nanochat-compatible interface.

Source from the content-addressed store, hash-verified

# HuggingFace loading utilities

class ModelWrapper:
    """Lightweight wrapper to give HuggingFace models a nanochat-compatible interface."""
    def __init__(self, model, max_seq_len=None):
        self.model = model
        self.max_seq_len = max_seq_len

    def __call__(self, input_ids, targets=None, loss_reduction='mean'):
        logits = self.model(input_ids).logits
        if targets is None:
            return logits
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
            reduction=loss_reduction
        )
        return loss

    def get_device(self):
        return next(self.model.parameters()).device


def load_hf_model(hf_path: str, device):
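
Usage sketch

The following is a minimal sketch of how the wrapper behaves, not code from the repository. The class body is copied from the excerpt above so the block runs on its own. `TinyHFLikeModel` is a hypothetical stand-in (an assumption for illustration) that mimics a HuggingFace causal LM only in that its forward pass returns an object with a `.logits` attribute. The vocabulary size, tensor shapes, and seed are likewise arbitrary.

```python
import types

import torch


class ModelWrapper:
    """Copied from the excerpt above so this sketch is self-contained."""
    def __init__(self, model, max_seq_len=None):
        self.model = model
        self.max_seq_len = max_seq_len

    def __call__(self, input_ids, targets=None, loss_reduction='mean'):
        logits = self.model(input_ids).logits
        if targets is None:
            return logits
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
            reduction=loss_reduction
        )
        return loss

    def get_device(self):
        return next(self.model.parameters()).device


class TinyHFLikeModel(torch.nn.Module):
    """Hypothetical stand-in: returns an object exposing `.logits`."""
    def __init__(self, vocab_size=11, dim=8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        logits = self.head(self.embed(input_ids))
        return types.SimpleNamespace(logits=logits)


torch.manual_seed(0)
wrapped = ModelWrapper(TinyHFLikeModel(), max_seq_len=16)

input_ids = torch.randint(0, 11, (2, 5))   # batch of 2, sequence length 5
targets = input_ids.clone()
targets[:, 0] = -1                         # -1 marks positions excluded from the loss

# Without targets, the call returns raw logits of shape (batch, seq, vocab).
logits = wrapped(input_ids)

# With targets, the call returns the cross-entropy loss; the default
# reduction averages over the non-ignored positions only.
mean_loss = wrapped(input_ids, targets)

# reduction='none' returns one loss per flattened token position;
# ignored positions contribute exactly 0.
per_token = wrapped(input_ids, targets, loss_reduction='none')

device = wrapped.get_device()              # device of the first model parameter
```

Because the logits and targets are flattened before the loss is computed, `loss_reduction='none'` yields a 1-D tensor of length batch × sequence, so a caller that needs per-sequence values has to reshape it.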

Callers 1

load_hf_model · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected