`test_multi(self)`: Translate images using StarGAN trained on multiple datasets.
| 550 | print('Saved real and fake images into {}...'.format(result_path)) |
| 551 | |
| 552 | def test_multi(self): |
| 553 | """Translate images using StarGAN trained on multiple datasets.""" |
| 554 | # Load the trained generator. |
| 555 | self.restore_model(self.test_iters) |
| 556 | |
| 557 | with torch.no_grad(): |
| 558 | for i, (x_real, c_org) in enumerate(self.celeba_loader): |
| 559 | |
| 560 | # Prepare input images and target domain labels. |
| 561 | x_real = x_real.to(self.device) |
| 562 | c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs) |
| 563 | c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD') |
| 564 | zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device) # Zero vector for CelebA. |
| 565 | zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD. |
| 566 | mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device) # Mask vector: [1, 0]. |
| 567 | mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device) # Mask vector: [0, 1]. |
| 568 | |
| 569 | # Translate images. |
| 570 | x_fake_list = [x_real] |
| 571 | for c_celeba in c_celeba_list: |
| 572 | c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1) |
| 573 | x_fake_list.append(self.G(x_real, c_trg)) |
| 574 | for c_rafd in c_rafd_list: |
| 575 | c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1) |
| 576 | x_fake_list.append(self.G(x_real, c_trg)) |
| 577 | |
| 578 | # Save the translated images. |
| 579 | x_concat = torch.cat(x_fake_list, dim=3) |
| 580 | result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) |
| 581 | save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) |
| 582 | print('Saved real and fake images into {}...'.format(result_path)) |
No test coverage detected.