MCPcopy
hub / github.com/yunjey/stargan / test_multi

Method test_multi

solver.py:552–582  ·  view source on GitHub ↗

Translate images using StarGAN trained on multiple datasets.

(self)

Source from the content-addressed store, hash-verified

550 print('Saved real and fake images into {}...'.format(result_path))
551
def test_multi(self):
    """Translate images using StarGAN trained on multiple datasets.

    Restores the generator checkpoint for ``self.test_iters``. For every
    batch of ``self.celeba_loader`` it generates one fake image per CelebA
    target domain and one per RaFD target domain. It then concatenates the
    real images followed by all fakes along dim 3 (image width, assuming
    NCHW tensors). The result is written to
    ``<result_dir>/<batch index, starting at 1>-images.jpg``.

    The generator's conditioning vector is ``[c_celeba, c_rafd, mask]``
    concatenated along dim 1. The label slice of the dataset that is not
    being targeted is zero-filled, and the 2-way mask selects the active
    dataset.
    """
    # Load the trained generator.
    # NOTE(review): nothing here switches self.G to eval mode; confirm whether
    # restore_model does so, or whether train-mode normalization is intended.
    self.restore_model(self.test_iters)

    with torch.no_grad():
        for batch_idx, (x_real, c_org) in enumerate(self.celeba_loader, start=1):

            # Prepare input images and target domain labels.
            x_real = x_real.to(self.device)
            # Read the batch size per batch rather than caching it once.
            batch_size = x_real.size(0)
            c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
            c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
            # Allocate directly on the target device instead of creating on CPU and copying.
            zero_celeba = torch.zeros(batch_size, self.c_dim, device=self.device)  # Zero vector for CelebA.
            zero_rafd = torch.zeros(batch_size, self.c2_dim, device=self.device)  # Zero vector for RaFD.
            mask_celeba = self.label2onehot(torch.zeros(batch_size), 2).to(self.device)  # Mask vector: [1, 0].
            mask_rafd = self.label2onehot(torch.ones(batch_size), 2).to(self.device)  # Mask vector: [0, 1].

            # Translate images: the real batch first, then one fake per target domain.
            x_fake_list = [x_real]
            for c_celeba in c_celeba_list:
                c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
                x_fake_list.append(self.G(x_real, c_trg))
            for c_rafd in c_rafd_list:
                c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
                x_fake_list.append(self.G(x_real, c_trg))

            # Save the translated images.
            x_concat = torch.cat(x_fake_list, dim=3)
            result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(batch_idx))
            # Inside no_grad the tensor carries no autograd history, so the
            # legacy ``.data`` accessor is unnecessary before moving to CPU.
            save_image(self.denorm(x_concat.cpu()), result_path, nrow=1, padding=0)
            print('Saved real and fake images into {}...'.format(result_path))

Callers 1

main (Function) · 0.95

Calls 4

restore_model (Method) · 0.95
create_labels (Method) · 0.95
label2onehot (Method) · 0.95
denorm (Method) · 0.95

Tested by

no test coverage detected