MCPcopy
hub / github.com/yunjey/stargan / train_multi

Method train_multi

solver.py:341–521  ·  view source on GitHub ↗

Train StarGAN with multiple datasets.

(self)

Source from the content-addressed store, hash-verified

339 print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
340
341 def train_multi(self):
342 """Train StarGAN with multiple datasets."""
343 # Data iterators.
344 celeba_iter = iter(self.celeba_loader)
345 rafd_iter = iter(self.rafd_loader)
346
347 # Fetch fixed inputs for debugging.
348 x_fixed, c_org = next(celeba_iter)
349 x_fixed = x_fixed.to(self.device)
350 c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
351 c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
352 zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device) # Zero vector for CelebA.
353 zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD.
354 mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device) # Mask vector: [1, 0].
355 mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device) # Mask vector: [0, 1].
356
357 # Learning rate cache for decaying.
358 g_lr = self.g_lr
359 d_lr = self.d_lr
360
361 # Start training from scratch or resume training.
362 start_iters = 0
363 if self.resume_iters:
364 start_iters = self.resume_iters
365 self.restore_model(self.resume_iters)
366
367 # Start training.
368 print('Start training...')
369 start_time = time.time()
370 for i in range(start_iters, self.num_iters):
371 for dataset in ['CelebA', 'RaFD']:
372
373 # =================================================================================== #
374 # 1. Preprocess input data #
375 # =================================================================================== #
376
377 # Fetch real images and labels.
378 data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
379
380 try:
381 x_real, label_org = next(data_iter)
382 except:
383 if dataset == 'CelebA':
384 celeba_iter = iter(self.celeba_loader)
385 x_real, label_org = next(celeba_iter)
386 elif dataset == 'RaFD':
387 rafd_iter = iter(self.rafd_loader)
388 x_real, label_org = next(rafd_iter)
389
390 # Generate target domain labels randomly.
391 rand_idx = torch.randperm(label_org.size(0))
392 label_trg = label_org[rand_idx]
393
394 if dataset == 'CelebA':
395 c_org = label_org.clone()
396 c_trg = label_trg.clone()
397 zero = torch.zeros(x_real.size(0), self.c2_dim)
398 mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)

Callers 1

main — Function · 0.95

Calls 9

create_labels — Method · 0.95
label2onehot — Method · 0.95
restore_model — Method · 0.95
classification_loss — Method · 0.95
gradient_penalty — Method · 0.95
reset_grad — Method · 0.95
denorm — Method · 0.95
update_lr — Method · 0.95
scalar_summary — Method · 0.80

Tested by

no test coverage detected