| 22 | self.args = args |
| 23 | |
def update(self, minibatches, opt, sch, device='cuda'):
    """Run one joint training step of the classifier and domain discriminator.

    All minibatches are concatenated and encoded once by ``self.featurizer``.
    The resulting features feed two heads:

    * ``self.classifier`` directly, scored against the class labels;
    * ``self.discriminator`` after ``Adver_network.ReverseLayerF.apply``
      (called with ``self.args.alpha``), scored against domain labels where
      each sample's label is the index of the minibatch it came from.
      NOTE(review): ReverseLayerF is defined outside this view; presumably a
      gradient-reversal layer scaled by ``alpha`` -- confirm.

    The optimised loss is the unweighted sum of both cross-entropy terms.

    Args:
        minibatches: Sequence with one entry per domain; ``entry[0]`` is the
            input batch and ``entry[1]`` the class-label batch.
        opt: Optimizer; ``zero_grad`` and ``step`` are called once here.
        sch: Optional scheduler; ``sch.step()`` is called after ``opt.step()``
            when ``sch`` is truthy.
        device: Device the inputs, class labels and domain labels are placed
            on. Defaults to ``'cuda'`` (the previous hard-coded behavior);
            pass the device the model lives on, e.g. ``'cpu'``.

    Returns:
        dict of Python numbers: ``'total'`` (summed loss), ``'class'``
        (classification loss) and ``'dis'`` (discriminator loss).
    """
    all_x = torch.cat([data[0].to(device).float() for data in minibatches])
    all_y = torch.cat([data[1].to(device).long() for data in minibatches])
    all_z = self.featurizer(all_x)

    disc_input = Adver_network.ReverseLayerF.apply(all_z, self.args.alpha)
    disc_out = self.discriminator(disc_input)
    # Domain label = position of the source minibatch in `minibatches`,
    # repeated once per sample of that minibatch.
    disc_labels = torch.cat([
        torch.full((data[0].shape[0], ), i, dtype=torch.int64, device=device)
        for i, data in enumerate(minibatches)
    ])
    disc_loss = F.cross_entropy(disc_out, disc_labels)

    # The classifier consumes the raw features, not the ReverseLayerF output.
    all_preds = self.classifier(all_z)
    classifier_loss = F.cross_entropy(all_preds, all_y)

    loss = classifier_loss + disc_loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    if sch:
        sch.step()
    return {'total': loss.item(), 'class': classifier_loss.item(), 'dis': disc_loss.item()}
| 49 | |
def predict(self, x):
    """Return class scores for ``x``: encode with the featurizer, then classify."""
    features = self.featurizer(x)
    return self.classifier(features)