(self, minibatches, opt, sch)
| 15 | self.num_classes = args.num_classes |
| 16 | |
def update(self, minibatches, opt, sch):
    """Run one optimisation step with gradient-guided feature muting.

    The features with the largest gradient of the true-class logit are
    zeroed, but only for the samples whose true-class probability drops the
    most when that muting is applied. The cross-entropy of the resulting
    predictions is then minimised.

    Args:
        minibatches: Iterable of items where ``item[0]`` is an input tensor
            and ``item[1]`` the matching label tensor. All items are
            concatenated into a single batch.
        opt: Optimizer providing ``zero_grad()`` and ``step()``.
        sch: Optional scheduler; ``sch.step()`` is called after the
            optimizer step when ``sch`` is truthy.

    Returns:
        ``{'class': loss}`` where ``loss`` is the cross-entropy value as a
        Python float.

    Note:
        ``self.drop_f`` and ``self.drop_b`` are handed to ``np.percentile``
        as ``q`` and are therefore expected to lie in ``[0, 100]``.
        The featurizer output is treated as 2-D (per-sample percentiles are
        taken along axis 1).
    """
    # Run on the device that already holds the model instead of assuming the
    # default CUDA device; fall back to CUDA when no parameter is available.
    featurizer_params = getattr(self.featurizer, "parameters", None)
    first_param = (
        next(iter(featurizer_params()), None)
        if callable(featurizer_params)
        else None
    )
    device = (
        first_param.device if first_param is not None else torch.device("cuda")
    )

    all_x = torch.cat([data[0].to(device).float() for data in minibatches])
    all_y = torch.cat([data[1].to(device).long() for data in minibatches])
    all_o = torch.nn.functional.one_hot(all_y, self.num_classes)
    all_f = self.featurizer(all_x)
    all_p = self.classifier(all_f)

    # Equation (1): compute gradients with respect to representation
    all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

    # Everything below only produces constant 0/1 masks, so autograd
    # tracking is switched off to avoid building graphs that are never used.
    with torch.no_grad():
        # Equation (2): compute top-gradient-percentile mask
        percentiles = np.percentile(
            all_g.detach().cpu().numpy(), self.drop_f, axis=1
        )
        percentiles = torch.as_tensor(
            percentiles, dtype=all_g.dtype, device=all_g.device
        )
        # One threshold per sample; broadcasting compares it against every
        # feature of that sample. 1 = keep, 0 = mute.
        mask_f = all_g.lt(percentiles.unsqueeze(1)).float()

        # Equation (3): mute top-gradient-percentile activations
        all_f_muted = all_f * mask_f

        # Equation (4): compute muted predictions
        all_p_muted = self.classifier(all_f_muted)

        # Section 3.3: Batch Percentage
        all_s = F.softmax(all_p, dim=1)
        all_s_muted = F.softmax(all_p_muted, dim=1)
        # Drop in true-class probability caused by the muting, per sample.
        changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
        percentile = float(np.percentile(changes.cpu().numpy(), self.drop_b))
        # Samples whose drop is below the batch percentile keep all features:
        # OR-ing with a row of ones cancels the feature mask for them.
        mask_b = changes.lt(percentile).float().view(-1, 1)
        mask = torch.logical_or(mask_f, mask_b).float()

    # Equations (3) and (4) again, this time muting over examples
    all_p_muted_again = self.classifier(all_f * mask)

    # Equation (5): update
    loss = F.cross_entropy(all_p_muted_again, all_y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if sch:
        sch.step()

    return {'class': loss.item()}
no test coverage detected