MCPcopy
hub / github.com/jindongwang/transferlearning / update

Method update

code/DeepDG/alg/algs/DIFEX.py:71–113  ·  view source on GitHub ↗
(self, minibatches, opt, sch)

Source from the content-addressed store, hash-verified

69 return mean_diff + cova_diff
70
71 def update(self, minibatches, opt, sch):
72 all_x = torch.cat([data[0].cuda().float() for data in minibatches])
73 all_y = torch.cat([data[1].cuda().long() for data in minibatches])
74 with torch.no_grad():
75 all_x1 = torch.angle(torch.fft.fftn(all_x, dim=(2, 3)))
76 tfea = self.teab(self.teaf(all_x1)).detach()
77
78 all_z = self.bottleneck(self.featurizer(all_x))
79 loss1 = F.cross_entropy(self.classifier(all_z), all_y)
80
81 loss2 = F.mse_loss(all_z[:, :self.tfbd], tfea)*self.args.alpha
82 if self.args.disttype == '2-norm':
83 loss3 = -F.mse_loss(all_z[:, :self.tfbd],
84 all_z[:, self.tfbd:])*self.args.beta
85 elif self.args.disttype == 'norm-2-norm':
86 loss3 = -F.mse_loss(all_z[:, :self.tfbd]/torch.norm(all_z[:, :self.tfbd], dim=1, keepdim=True),
87 all_z[:, self.tfbd:]/torch.norm(all_z[:, self.tfbd:], dim=1, keepdim=True))*self.args.beta
88 elif self.args.disttype == 'norm-1-norm':
89 loss3 = -F.l1_loss(all_z[:, :self.tfbd]/torch.norm(all_z[:, :self.tfbd], dim=1, keepdim=True),
90 all_z[:, self.tfbd:]/torch.norm(all_z[:, self.tfbd:], dim=1, keepdim=True))*self.args.beta
91 elif self.args.disttype == 'cos':
92 loss3 = torch.mean(F.cosine_similarity(
93 all_z[:, :self.tfbd], all_z[:, self.tfbd:]))*self.args.beta
94 loss4 = 0
95 if len(minibatches) > 1:
96 for i in range(len(minibatches)-1):
97 for j in range(i+1, len(minibatches)):
98 loss4 += self.coral(all_z[i*self.args.batch_size:(i+1)*self.args.batch_size, self.tfbd:],
99 all_z[j*self.args.batch_size:(j+1)*self.args.batch_size, self.tfbd:])
100 loss4 = loss4*2/(len(minibatches) *
101 (len(minibatches)-1))*self.args.lam
102 else:
103 loss4 = self.coral(all_z[:self.args.batch_size//2, self.tfbd:],
104 all_z[self.args.batch_size//2:, self.tfbd:])
105 loss4 = loss4*self.args.lam
106
107 loss = loss1+loss2+loss3+loss4
108 opt.zero_grad()
109 loss.backward()
110 opt.step()
111 if sch:
112 sch.step()
113 return {'class': loss1.item(), 'dist': (loss2).item(), 'exp': (loss3).item(), 'align': loss4.item(), 'total': loss.item()}
114
def predict(self, x):
    """Return the classifier output for input batch ``x``.

    Inference path is featurizer -> bottleneck -> classifier; the teacher
    branch is not involved.
    """
    features = self.featurizer(x)
    embedding = self.bottleneck(features)
    return self.classifier(embedding)

Callers

nothing calls this directly

Calls 4

coral — method · 0.95
step — method · 0.80
mean — method · 0.45
backward — method · 0.45

Tested by

no test coverage detected