Repository: github.com/jindongwang/transferlearning

Method update

code/DeepDG/alg/algs/DANN.py:24–48
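Signature: update(self, minibatches, opt, sch)
minibatches: one batch per source domain (item 0 inputs, item 1 class labels) · opt: optimizer · sch: optional scheduler, stepped after opt and skipped if None
Returns a dict of scalar losses: 'total', 'class', 'dis'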

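```python
# Excerpt from the DANN class body in code/DeepDG/alg/algs/DANN.py.
# The module relies on torch, torch.nn.functional imported as F, and the
# repository's Adver_network module.
        self.args = args

    def update(self, minibatches, opt, sch):
        # minibatches holds one batch per source domain; item 0 is the input
        # tensor and item 1 the class labels. Merge all domains into one batch.
        all_x = torch.cat([data[0].cuda().float() for data in minibatches])
        all_y = torch.cat([data[1].cuda().long() for data in minibatches])
        all_z = self.featurizer(all_x)

        # Gradient reversal between featurizer and discriminator: features pass
        # through unchanged, but their gradient is reversed and scaled by alpha.
        disc_input = all_z
        disc_input = Adver_network.ReverseLayerF.apply(
            disc_input, self.args.alpha)
        disc_out = self.discriminator(disc_input)
        # Domain label i for every sample taken from the i-th minibatch.
        disc_labels = torch.cat([
            torch.full((data[0].shape[0], ), i,
                       dtype=torch.int64, device='cuda')
            for i, data in enumerate(minibatches)
        ])

        disc_loss = F.cross_entropy(disc_out, disc_labels)
        all_preds = self.classifier(all_z)
        classifier_loss = F.cross_entropy(all_preds, all_y)
        # Plain sum: the reversal layer already flips the domain-loss gradient
        # that reaches the featurizer.
        loss = classifier_loss+disc_loss
        opt.zero_grad()
        loss.backward()
        opt.step()
        if sch:
            sch.step()
        return {'total': loss.item(), 'class': classifier_loss.item(), 'dis': disc_loss.item()}

    def predict(self, x):
        return self.classifier(self.featurizer(x))
```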
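The per-domain bookkeeping at the top of update is easy to check in isolation. The CPU-only sketch below reproduces it with two hypothetical source domains of 3 and 2 samples (the helper name build_batch and the toy tensors are illustrative, not from the repository): inputs and class labels are concatenated, every sample gets its minibatch index as a domain label, and the discriminator therefore needs one output logit per source domain.

```python
import torch
import torch.nn.functional as F


def build_batch(minibatches):
    """Hypothetical CPU-only mirror of the batching done in DANN.update."""
    # Stack inputs and class labels from all source domains into one batch.
    all_x = torch.cat([data[0].float() for data in minibatches])
    all_y = torch.cat([data[1].long() for data in minibatches])
    # Every sample from the i-th minibatch gets domain label i.
    disc_labels = torch.cat([
        torch.full((data[0].shape[0],), i, dtype=torch.int64)
        for i, data in enumerate(minibatches)
    ])
    return all_x, all_y, disc_labels


torch.manual_seed(0)
# Two toy source domains: 3 and 2 samples with 4 features each.
minibatches = [
    (torch.randn(3, 4), torch.tensor([0, 1, 2])),
    (torch.randn(2, 4), torch.tensor([1, 1])),
]
all_x, all_y, disc_labels = build_batch(minibatches)
print(all_x.shape)           # torch.Size([5, 4])
print(disc_labels.tolist())  # [0, 0, 0, 1, 1]

# The discriminator must emit one logit per source domain (here 2).
disc_logits = torch.randn(all_x.shape[0], len(minibatches))
disc_loss = F.cross_entropy(disc_logits, disc_labels)
```

Because domain labels come from list position, the minibatches should arrive in the same domain order on every call; otherwise the discriminator's targets would shift between steps.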

Callers

nothing calls this directly

Calls (2)

step (method) · 0.80
backward (method) · 0.45
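Of the two calls listed, backward is where DANN's adversarial objective takes effect. In update, Adver_network.ReverseLayerF sits between the featurizer and the discriminator with coefficient self.args.alpha. Assuming ReverseLayerF is a standard gradient reversal layer (identity in the forward pass, gradient multiplied by -alpha in the backward pass), the minimal sketch below reproduces that behavior with a hypothetical GradReverse class and toy linear layers, neither of which comes from the repository, and checks it numerically.

```python
import torch
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    """Hypothetical stand-in for ReverseLayerF: identity forward, -alpha * grad backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Negate and scale the incoming gradient; alpha itself gets no gradient.
        return grad_output.neg() * ctx.alpha, None


torch.manual_seed(0)
featurizer = torch.nn.Linear(4, 3)     # toy feature extractor
discriminator = torch.nn.Linear(3, 2)  # toy domain classifier (2 domains)
x = torch.randn(5, 4)
domains = torch.tensor([0, 0, 0, 1, 1])


def grads(alpha=None):
    """Return (featurizer, discriminator) weight gradients of the domain loss."""
    featurizer.zero_grad()
    discriminator.zero_grad()
    z = featurizer(x)
    if alpha is not None:
        z = GradReverse.apply(z, alpha)
    loss = F.cross_entropy(discriminator(z), domains)
    loss.backward()
    return featurizer.weight.grad.clone(), discriminator.weight.grad.clone()


plain_f, plain_d = grads()
rev_f, rev_d = grads(alpha=0.5)
print(torch.allclose(rev_d, plain_d))         # True: discriminator unaffected
print(torch.allclose(rev_f, -0.5 * plain_f))  # True: featurizer gradient flipped
```

Summing classifier_loss and disc_loss, as update does, therefore still trains the discriminator to separate domains, while the featurizer receives -alpha times that gradient and is pushed toward features the discriminator cannot tell apart.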

Tested by

no test coverage detected