MCPcopy
hub / github.com/jindongwang/transferlearning / update

Method update

code/DeepDG/alg/algs/RSC.py:17–58  ·  view source on GitHub ↗
(self, minibatches, opt, sch)

Source from the content-addressed store, hash-verified

15 self.num_classes = args.num_classes
16
def update(self, minibatches, opt, sch):
    """Run one Representation Self-Challenging (RSC) optimisation step.

    The features with the largest gradients w.r.t. the true-class logit
    are muted, but only for the examples whose true-class probability
    drops the most when muted; the cross-entropy on the resulting
    predictions is then minimised.

    Args:
        minibatches: Iterable of items where ``item[0]`` holds the inputs
            and ``item[1]`` the integer class labels; all items are
            concatenated along dim 0.
        opt: Optimizer; ``zero_grad()`` and ``step()`` are called on it.
        sch: Learning-rate scheduler; ``step()`` is called when truthy.

    Returns:
        dict: ``{'class': <loss value as a Python float>}``.

    Note:
        ``self.drop_f`` and ``self.drop_b`` are passed straight to
        ``np.percentile`` as the percentile to compute. The featurizer
        output is expected to be 2-D (batch, features) -- TODO confirm
        for featurizers that do not flatten their output.
    """
    # Follow the device the model actually lives on rather than assuming
    # CUDA, so the step also works on CPU-only hosts. When the featurizer
    # exposes no parameters, keep the historical behaviour (CUDA) whenever
    # a GPU is available.
    params_fn = getattr(self.featurizer, 'parameters', None)
    first_param = next(iter(params_fn()), None) if callable(params_fn) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_x = torch.cat([data[0].to(device).float() for data in minibatches])
    all_y = torch.cat([data[1].to(device).long() for data in minibatches])
    all_o = torch.nn.functional.one_hot(all_y, self.num_classes)
    all_f = self.featurizer(all_x)
    all_p = self.classifier(all_f)

    # Equation (1): compute gradients with respect to representation
    all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

    # Equation (2): compute top-gradient-percentile mask
    # One threshold per example (axis=1); it is kept as a column vector so
    # the comparison broadcasts across the feature dimension.
    percentiles = np.percentile(
        all_g.detach().cpu().numpy(), self.drop_f, axis=1)
    percentiles = torch.Tensor(percentiles).unsqueeze(1).to(device)
    mask_f = all_g.lt(percentiles).float()

    # Equation (3): mute top-gradient-percentile activations
    all_f_muted = all_f * mask_f

    # Equation (4): compute muted predictions
    all_p_muted = self.classifier(all_f_muted)

    # Section 3.3: Batch Percentage
    all_s = F.softmax(all_p, dim=1)
    all_s_muted = F.softmax(all_p_muted, dim=1)
    # Drop in true-class probability caused by muting, one value per example.
    changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
    percentile = np.percentile(changes.detach().cpu().numpy(), self.drop_b)
    # Examples below the batch threshold get an all-ones row after the OR,
    # i.e. their features are left untouched.
    mask_b = changes.lt(percentile).float().view(-1, 1)
    mask = torch.logical_or(mask_f, mask_b).float()

    # Equations (3) and (4) again, this time muting over examples
    all_p_muted_again = self.classifier(all_f * mask)

    # Equation (5): update
    loss = F.cross_entropy(all_p_muted_again, all_y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if sch:
        sch.step()

    return {'class': loss.item()}

Callers 6

train.pyFile · 0.45
load_pretrained_modelFunction · 0.45
load_multilingual_dataFunction · 0.45
load_pretrained_modelFunction · 0.45

Calls 3

sumMethod · 0.80
stepMethod · 0.80
backwardMethod · 0.45

Tested by

no test coverage detected