`train_multi(self)`: Train StarGAN with multiple datasets.
| 339 | print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) |
| 340 | |
| 341 | def train_multi(self): |
| 342 | """Train StarGAN with multiple datasets.""" |
| 343 | # Data iterators. |
| 344 | celeba_iter = iter(self.celeba_loader) |
| 345 | rafd_iter = iter(self.rafd_loader) |
| 346 | |
| 347 | # Fetch fixed inputs for debugging. |
| 348 | x_fixed, c_org = next(celeba_iter) |
| 349 | x_fixed = x_fixed.to(self.device) |
| 350 | c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs) |
| 351 | c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD') |
| 352 | zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device) # Zero vector for CelebA. |
| 353 | zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD. |
| 354 | mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device) # Mask vector: [1, 0]. |
| 355 | mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device) # Mask vector: [0, 1]. |
| 356 | |
| 357 | # Learning rate cache for decaying. |
| 358 | g_lr = self.g_lr |
| 359 | d_lr = self.d_lr |
| 360 | |
| 361 | # Start training from scratch or resume training. |
| 362 | start_iters = 0 |
| 363 | if self.resume_iters: |
| 364 | start_iters = self.resume_iters |
| 365 | self.restore_model(self.resume_iters) |
| 366 | |
| 367 | # Start training. |
| 368 | print('Start training...') |
| 369 | start_time = time.time() |
| 370 | for i in range(start_iters, self.num_iters): |
| 371 | for dataset in ['CelebA', 'RaFD']: |
| 372 | |
| 373 | # =================================================================================== # |
| 374 | # 1. Preprocess input data # |
| 375 | # =================================================================================== # |
| 376 | |
| 377 | # Fetch real images and labels. |
| 378 | data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter |
| 379 | |
| 380 | try: |
| 381 | x_real, label_org = next(data_iter) |
| 382 | except: |
| 383 | if dataset == 'CelebA': |
| 384 | celeba_iter = iter(self.celeba_loader) |
| 385 | x_real, label_org = next(celeba_iter) |
| 386 | elif dataset == 'RaFD': |
| 387 | rafd_iter = iter(self.rafd_loader) |
| 388 | x_real, label_org = next(rafd_iter) |
| 389 | |
| 390 | # Generate target domain labels randomly. |
| 391 | rand_idx = torch.randperm(label_org.size(0)) |
| 392 | label_trg = label_org[rand_idx] |
| 393 | |
| 394 | if dataset == 'CelebA': |
| 395 | c_org = label_org.clone() |
| 396 | c_trg = label_trg.clone() |
| 397 | zero = torch.zeros(x_real.size(0), self.c2_dim) |
| 398 | mask = self.label2onehot(torch.zeros(x_real.size(0)), 2) |
no test coverage detected