Translate images using StarGAN trained on a single dataset.
test(self)
| 521 | print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) |
| 522 | |
def test(self):
    """Translate images using StarGAN trained on a single dataset.

    Restores the generator checkpoint for iteration ``self.test_iters``.
    It then iterates over the test loader selected by ``self.dataset``. For
    each batch, it builds the list of target labels with
    ``self.create_labels`` and runs ``self.G`` once per target label. The
    real batch and all translations are concatenated along dim 3 (presumably
    the width axis of NCHW image tensors; TODO confirm) and saved, after
    ``self.denorm``, to ``<self.result_dir>/<batch_index + 1>-images.jpg``.

    Raises:
        ValueError: if ``self.dataset`` is neither 'CelebA' nor 'RaFD'.
    """
    # Load the trained generator.
    self.restore_model(self.test_iters)

    # Set data loader. An unsupported dataset name previously left
    # `data_loader` unbound and surfaced as an UnboundLocalError inside the
    # loop; fail with an explicit, descriptive error instead.
    if self.dataset == 'CelebA':
        data_loader = self.celeba_loader
    elif self.dataset == 'RaFD':
        data_loader = self.rafd_loader
    else:
        raise ValueError(
            "test() supports dataset 'CelebA' or 'RaFD', got {!r}.".format(self.dataset))

    # Inference only: no autograd graph is needed for the generated images.
    with torch.no_grad():
        for i, (x_real, c_org) in enumerate(data_loader):
            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

            # Translate images: the real batch first, then one generated
            # batch per target label, so the saved grid keeps that order.
            x_fake_list = [x_real]
            for c_trg in c_trg_list:
                x_fake_list.append(self.G(x_real, c_trg))

            # Save the translated images (file names are 1-based batch indices).
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
            # `.data` is unnecessary here: under no_grad the tensor carries no graph.
            save_image(self.denorm(x_concat.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))
| 551 | |
| 552 | def test_multi(self): |
| 553 | """Translate images using StarGAN trained on multiple datasets.""" |
No test coverage detected.