Repository: github.com/jindongwang/transferlearning

Method update

code/DeepDG/alg/algs/MMD.py:38–65
Signature: update(self, minibatches, opt, sch)


        return Kxx + Kyy - 2 * Kxy

    def update(self, minibatches, opt, sch):
        objective = 0
        penalty = 0
        nmb = len(minibatches)

        features = [self.featurizer(
            data[0].cuda().float()) for data in minibatches]
        classifs = [self.classifier(fi) for fi in features]
        targets = [data[1].cuda().long() for data in minibatches]

        for i in range(nmb):
            objective += F.cross_entropy(classifs[i], targets[i])
            for j in range(i + 1, nmb):
                penalty += self.mmd(features[i], features[j])

        objective /= nmb
        if nmb > 1:
            penalty /= (nmb * (nmb - 1) / 2)

        opt.zero_grad()
        (objective + (self.args.mmd_gamma*penalty)).backward()
        opt.step()
        if sch:
            sch.step()
        if torch.is_tensor(penalty):
            penalty = penalty.item()

        return {'class': objective.item(), 'mmd': penalty, 'total': (objective.item() + (self.args.mmd_gamma*penalty))}
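
Reading the loop and the two divisions above, the loss minimized at each step can be written as follows. The notation is ours, not the repository's: n = nmb = len(minibatches), x_i and y_i are the inputs and targets of minibatch i, f is self.featurizer, g is self.classifier and γ is self.args.mmd_gamma. Dividing by nmb * (nmb - 1) / 2 turns the summed penalty into the mean over unordered minibatch pairs. When n = 1 there are no pairs and the penalty term is 0. K_XX, K_YY and K_XY stand for the three kernel statistics combined on the return line of self.mmd. How they are computed lies outside this excerpt.

```latex
\[
\mathcal{L} = \frac{1}{n}\sum_{i=1}^{n}\mathrm{CE}\bigl(g(f(x_i)),\, y_i\bigr)
  + \gamma \cdot \frac{2}{n(n-1)} \sum_{1 \le i < j \le n} \mathrm{MMD}\bigl(f(x_i),\, f(x_j)\bigr)
\]
\[
\mathrm{MMD}(X, Y) = K_{XX} + K_{YY} - 2\,K_{XY}
\]
```

The 'total' entry in the returned dict is this same quantity, converted to a Python float.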

Callers

nothing calls this directly

Calls (3)

mmd (method) · 0.95
step (method) · 0.80
backward (method) · 0.45
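
The first call listed, mmd, ends with return Kxx + Kyy - 2 * Kxy, which is visible at the top of the excerpt. If Kxx, Kyy and Kxy are average kernel values, this is the standard biased (V-statistic) estimate of squared MMD. The repository's kernel is not shown here. This minimal sketch therefore assumes a single-bandwidth Gaussian (RBF) kernel with an arbitrary gamma, operating on plain lists of floats. The names rbf, mean_kernel and mmd below are hypothetical illustrations, not the repository's code.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def rbf(a: Vector, b: Vector, gamma: float = 1.0) -> float:
    """Gaussian (RBF) kernel exp(-gamma * ||a - b||^2); gamma is an assumed bandwidth."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-gamma * sq_dist)


def mean_kernel(xs: List[Vector], ys: List[Vector], gamma: float = 1.0) -> float:
    """Average kernel value over every pair (x, y) with x from xs and y from ys."""
    total = sum(rbf(x, y, gamma) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd(xs: List[Vector], ys: List[Vector], gamma: float = 1.0) -> float:
    """Biased squared-MMD estimate: mean K(x, x') + mean K(y, y') - 2 * mean K(x, y)."""
    kxx = mean_kernel(xs, xs, gamma)
    kyy = mean_kernel(ys, ys, gamma)
    kxy = mean_kernel(xs, ys, gamma)
    return kxx + kyy - 2 * kxy
```

For example, mmd([[0.0]], [[1.0]]) is 2 - 2e^{-1} ≈ 1.2642. Comparing a set with itself gives 0, which is why the penalty vanishes when two minibatches share the same feature distribution.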

Tested by

no test coverage detected
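
No tests are detected for update. As written, the method calls .cuda() on every input and needs a featurizer and classifier, which makes it awkward to exercise in isolation. One lightweight option is to mirror only its loss bookkeeping with plain floats: cross-entropy averaged over minibatches, a penalty averaged over the nmb * (nmb - 1) / 2 unordered pairs, and the 'class', 'mmd' and 'total' keys it returns. This is a sketch under our own assumptions. aggregate_losses and pair_penalty are hypothetical names, the sketch performs no backpropagation, and it is not the repository's code.

```python
from itertools import combinations
from typing import Callable, Dict, Sequence


def aggregate_losses(
    ce_losses: Sequence[float],
    features: Sequence[object],
    pair_penalty: Callable[[object, object], float],
    gamma: float,
) -> Dict[str, float]:
    """Mirror update()'s bookkeeping: mean cross-entropy plus gamma times
    the mean penalty over unordered minibatch pairs (hypothetical helper)."""
    nmb = len(ce_losses)
    objective = sum(ce_losses) / nmb
    # combinations(range(nmb), 2) yields the same (i, j), i < j pairs as the nested loop.
    pairs = list(combinations(range(nmb), 2))
    penalty = sum(pair_penalty(features[i], features[j]) for i, j in pairs)
    if nmb > 1:
        penalty /= len(pairs)  # len(pairs) == nmb * (nmb - 1) / 2
    return {"class": objective, "mmd": penalty, "total": objective + gamma * penalty}
```

With three inputs the summed penalty is divided by 3 pairs. With a single input there are no pairs, so the penalty stays 0, matching the nmb > 1 guard in update.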