| 69 | return mean_diff + cova_diff |
| 70 | |
def update(self, minibatches, opt, sch):
    """Run one optimisation step on the combined four-term objective.

    The bottleneck features are split column-wise at ``self.tfbd``: the
    first ``tfbd`` columns are regressed onto features computed from the
    Fourier phase of the input, the remaining columns are aligned across
    minibatches with ``self.coral``, and the two parts are pushed apart
    by the distance selected through ``self.args.disttype``.

    Args:
        minibatches: sequence of indexable items; element 0 is the input
            tensor and element 1 the label tensor (presumably one item
            per training domain -- confirm against the caller). Inputs
            need at least 4 dimensions because the FFT runs over dims
            (2, 3).
        opt: optimizer; ``zero_grad()`` and ``step()`` are called on it.
        sch: scheduler or a falsy value; ``step()`` is called when truthy.

    Returns:
        dict of Python floats keyed by 'class', 'dist', 'exp', 'align'
        and 'total'.

    Raises:
        ValueError: if ``self.args.disttype`` is not one of '2-norm',
            'norm-2-norm', 'norm-1-norm' or 'cos'.
    """
    disttype = self.args.disttype
    # Fail before any forward pass instead of hitting an unbound `loss3`
    # when the losses are summed.
    if disttype not in ('2-norm', 'norm-2-norm', 'norm-1-norm', 'cos'):
        raise ValueError(
            "unsupported disttype %r; expected one of '2-norm', "
            "'norm-2-norm', 'norm-1-norm', 'cos'" % (disttype,))

    all_x = torch.cat([data[0].cuda().float() for data in minibatches])
    all_y = torch.cat([data[1].cuda().long() for data in minibatches])

    # Target features come from the phase of the 2-D FFT of the input and
    # are treated as constants (no_grad + detach).
    # NOTE(review): teaf/teab look like a frozen teacher network -- confirm.
    with torch.no_grad():
        phase = torch.angle(torch.fft.fftn(all_x, dim=(2, 3)))
        tfea = self.teab(self.teaf(phase)).detach()

    all_z = self.bottleneck(self.featurizer(all_x))
    loss1 = F.cross_entropy(self.classifier(all_z), all_y)

    # Column-wise split of the bottleneck output at self.tfbd.
    z_head = all_z[:, :self.tfbd]
    z_tail = all_z[:, self.tfbd:]

    loss2 = F.mse_loss(z_head, tfea) * self.args.alpha

    # loss3 rewards dissimilarity between the two feature parts: negated
    # distances, or the (positive) cosine similarity to be minimised.
    if disttype == '2-norm':
        loss3 = -F.mse_loss(z_head, z_tail) * self.args.beta
    elif disttype == 'norm-2-norm':
        # F.normalize divides by max(||x||_2, eps), so an all-zero row no
        # longer turns the loss into NaN.
        loss3 = -F.mse_loss(F.normalize(z_head, dim=1),
                            F.normalize(z_tail, dim=1)) * self.args.beta
    elif disttype == 'norm-1-norm':
        loss3 = -F.l1_loss(F.normalize(z_head, dim=1),
                           F.normalize(z_tail, dim=1)) * self.args.beta
    else:  # 'cos'
        loss3 = torch.mean(
            F.cosine_similarity(z_head, z_tail)) * self.args.beta

    num_groups = len(minibatches)
    if num_groups > 1:
        # Slice per minibatch using the real row counts rather than
        # assuming every minibatch holds exactly args.batch_size rows.
        sizes = [data[0].shape[0] for data in minibatches]
        groups = torch.split(z_tail, sizes)
        loss4 = 0
        for i in range(num_groups - 1):
            for j in range(i + 1, num_groups):
                loss4 += self.coral(groups[i], groups[j])
        # Average over the n*(n-1)/2 unordered pairs.
        loss4 = loss4 * 2 / (num_groups * (num_groups - 1)) * self.args.lam
    else:
        # Single minibatch: align its first half against its second half.
        half = all_z.shape[0] // 2
        loss4 = self.coral(z_tail[:half], z_tail[half:])
        loss4 = loss4 * self.args.lam

    loss = loss1 + loss2 + loss3 + loss4
    opt.zero_grad()
    loss.backward()
    opt.step()
    if sch:
        sch.step()
    return {'class': loss1.item(), 'dist': loss2.item(),
            'exp': loss3.item(), 'align': loss4.item(),
            'total': loss.item()}
| 114 | |
def predict(self, x):
    """Return the classifier output for ``x``.

    The input is passed through ``self.featurizer``, then
    ``self.bottleneck``, then ``self.classifier``, in that order.
    """
    features = self.featurizer(x)
    embedding = self.bottleneck(features)
    return self.classifier(embedding)