MCPcopy
hub / github.com/karpathy/minGPT / eval_split

Function eval_split

projects/adder/adder.py:145–177  ·  view source on GitHub ↗
(trainer, split, max_batches=None)

Source from the content-addressed store, hash-verified

143
144 # helper function for the evaluation of a model
def eval_split(trainer, split, max_batches=None):
    """Measure how many addition problems the model answers correctly on one split.

    For each example, the first ``2 * ndigit`` tokens (the two ``ndigit``-wide
    operands) are given to ``model.generate``, which greedily produces
    ``ndigit + 1`` more tokens. Those answer tokens are flipped into normal
    digit order, decoded into an integer and compared against the sum of the
    two decoded operands. Up to 5 mistakes (across the whole call) are printed,
    followed by a one-line summary.

    Relies on module-level ``train_dataset``, ``test_dataset``, ``config``,
    ``model``, ``DataLoader`` and ``torch`` defined elsewhere in this file.

    Args:
        trainer: object exposing ``.device``; inputs and the place-value table
            are moved to that device.
        split: ``'train'`` or ``'test'``; any other value raises ``KeyError``.
        max_batches: if not None, evaluate at most this many batches of 100
            examples (0 evaluates none); None evaluates the whole split.

    Returns:
        A 0-d float tensor holding the number of correctly answered examples.
    """
    import itertools  # stdlib; used to cap the number of evaluated batches

    dataset = {'train': train_dataset, 'test': test_dataset}[split]
    ndigit = config.data.ndigit
    results = []  # one 0/1 entry per evaluated example, in loader order
    mistakes_printed_already = 0
    # place values [10**ndigit, ..., 10, 1] as a (1, ndigit+1) row used to decode digit tensors
    factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    # islice(loader, None) yields every batch; islice(loader, k) yields at most k batches
    for x, _ in itertools.islice(loader, max_batches):
        x = x.to(trainer.device)
        # isolate the two operands: the first 2*ndigit tokens of each sequence
        d1d2 = x[:, :ndigit*2]
        # let the model produce the ndigit+1 answer tokens
        d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False) # using greedy argmax, not sampling
        # keep only the ndigit+1 generated answer tokens
        d3 = d1d2d3[:, -(ndigit+1):]
        d3 = d3.flip(1) # reverse the digits to their "normal" order
        # decode the integers from individual digits (operands use the ndigit lowest place values)
        d1i = (d1d2[:, :ndigit] * factors[:, 1:]).sum(1)
        d2i = (d1d2[:, ndigit:ndigit*2] * factors[:, 1:]).sum(1)
        d3i_pred = (d3 * factors).sum(1)
        d3i_gt = d1i + d2i # manually calculate the ground truth
        # evaluate the correctness of the results in this batch
        correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
        results.extend(correct.int().tolist())
        # print up to 5 mistakes overall, in the order they occur, to get a sense of the errors
        for i in (~correct).nonzero().flatten().tolist():
            if mistakes_printed_already >= 5:
                break
            mistakes_printed_already += 1
            print("GPT claims that %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))
    rt = torch.tensor(results, dtype=torch.float)
    score = rt.sum()
    # the mean of an empty tensor is nan, so report 0% when nothing was evaluated
    percent = 100 * rt.mean() if results else 0.0
    print("%s final score: %d/%d = %.2f%% correct" % (split, score, len(results), percent))
    return score
178
179 # iteration callback
180 top_score = 0

Callers 1

batch_end_callback · Function · 0.85

Calls 1

generate · Method · 0.80

Tested by

no test coverage detected