(trainer, split, max_batches=None)
| 143 | |
| 144 | # helper function for the evaluation of a model |
def eval_split(trainer, split, max_batches=None):
    """Score the model's addition accuracy on one dataset split.

    For every batch, the first ``2*ndigit`` tokens of each sample (the digits
    of the two operands) are used as the prompt, the model greedily generates
    ``ndigit+1`` further tokens, and those tokens are decoded into an integer
    that is compared with the true sum of the two operands.

    Args:
        trainer: object exposing ``.device``; the batches and the place-value
            table are moved to that device.
        split: ``'train'`` or ``'test'``; selects the module-level
            ``train_dataset`` or ``test_dataset``. Any other value raises
            ``KeyError``.
        max_batches: optional cap on the number of batches (100 samples each)
            to evaluate; ``None`` evaluates the whole split. The cap is
            checked after a batch has been scored, so values <= 1 still
            evaluate one batch.

    Returns:
        A 0-dim float tensor holding the number of correctly added samples.

    Side effects:
        Prints up to five wrong answers and a one-line summary for the split.
    """
    dataset = {'train': train_dataset, 'test': test_dataset}[split]
    ndigit = config.data.ndigit
    results = []
    mistakes_printed_already = 0
    # place values, most significant first: [10**ndigit, ..., 10, 1]
    factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    # evaluation never backpropagates, so keep autograd from recording anything
    with torch.no_grad():
        for b, (x, _) in enumerate(loader):
            x = x.to(trainer.device)
            # isolate the digits of the two operands (the first 2*ndigit tokens)
            d1d2 = x[:, :ndigit*2]
            # let the model generate the rest of the sequence
            # NOTE(review): this is the module-level `model`, not one taken
            # from `trainer` -- presumably the same network; confirm.
            d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False)  # using greedy argmax, not sampling
            # isolate the ndigit+1 generated result digits
            d3 = d1d2d3[:, -(ndigit+1):]
            d3 = d3.flip(1)  # reverse the digits to their "normal" order
            # decode the integers from individual digits; the operands have
            # only ndigit digits, so the leading 10**ndigit column is dropped
            d1i = (d1d2[:, :ndigit] * factors[:, 1:]).sum(1)
            d2i = (d1d2[:, ndigit:ndigit*2] * factors[:, 1:]).sum(1)
            d3i_pred = (d3 * factors).sum(1)
            d3i_gt = d1i + d2i  # manually calculate the ground truth
            # evaluate the correctness of the results in this batch
            correct = (d3i_pred == d3i_gt).cpu()  # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
            # one bulk conversion per batch instead of indexing the tensor per sample
            flags = correct.tolist()
            results.extend(int(ok) for ok in flags)
            for i, ok in enumerate(flags):
                if mistakes_printed_already >= 5:  # only print up to 5 mistakes to get a sense
                    break
                if not ok:
                    mistakes_printed_already += 1
                    print("GPT claims that %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i]))
            if max_batches is not None and b+1 >= max_batches:
                break
    rt = torch.tensor(results, dtype=torch.float)
    print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
    return rt.sum()
| 178 | |
| 179 | # iteration callback |
| 180 | top_score = 0 |
no test coverage detected