| 36 | return Kxx + Kyy - 2 * Kxy |
| 37 | |
def update(self, minibatches, opt, sch, device='cuda'):
    """Run one optimisation step on classification loss plus an MMD penalty.

    Each element of ``minibatches`` is indexed as ``data[0]`` (inputs) and
    ``data[1]`` (labels). Presumably there is one minibatch per training
    domain; confirm against the caller. Features are computed per minibatch
    by ``self.featurizer`` and classified by ``self.classifier``.

    The objective is the mean cross-entropy over the minibatches. The
    penalty is the mean of ``self.mmd`` over every unordered pair of
    minibatch features, and it stays ``0`` when there is only one minibatch.
    The optimiser minimises ``objective + self.args.mmd_gamma * penalty``.

    Args:
        minibatches: non-empty sequence of items indexable as
            ``(inputs, labels)``.
        opt: optimiser; ``zero_grad`` and ``step`` are each called once.
        sch: optional scheduler; ``sch.step()`` is called once after the
            optimiser step when ``sch`` is truthy.
        device: device that inputs and labels are moved to. The default
            ``'cuda'`` matches the previous hard-coded ``.cuda()`` call.

    Returns:
        dict with the scalar values of the classification loss (``'class'``),
        the averaged MMD penalty (``'mmd'``) and the weighted sum
        (``'total'``).

    Raises:
        ValueError: if ``minibatches`` is empty. Previously this case failed
            with ZeroDivisionError.
    """
    nmb = len(minibatches)
    if nmb == 0:
        raise ValueError("update() requires at least one minibatch")

    features = [self.featurizer(data[0].to(device).float()) for data in minibatches]
    classifs = [self.classifier(fi) for fi in features]
    targets = [data[1].to(device).long() for data in minibatches]

    objective = 0
    penalty = 0
    for i in range(nmb):
        objective += F.cross_entropy(classifs[i], targets[i])
        # Only pairs with j > i are visited, so each unordered pair counts once.
        for j in range(i + 1, nmb):
            penalty += self.mmd(features[i], features[j])

    objective /= nmb
    if nmb > 1:
        # Average over the nmb * (nmb - 1) / 2 distinct pairs.
        penalty /= (nmb * (nmb - 1) / 2)

    gamma = self.args.mmd_gamma
    opt.zero_grad()
    (objective + gamma * penalty).backward()
    opt.step()
    if sch:
        sch.step()

    # With a single minibatch the penalty is still the int 0, not a tensor.
    if torch.is_tensor(penalty):
        penalty = penalty.item()
    class_loss = objective.item()
    return {'class': class_loss, 'mmd': penalty, 'total': class_loss + gamma * penalty}