Train StarGAN within a single dataset.
(self)
| 180 | return F.cross_entropy(logit, target) |
| 181 | |
| 182 | def train(self): |
| 183 | """Train StarGAN within a single dataset.""" |
| 184 | # Set data loader. |
| 185 | if self.dataset == 'CelebA': |
| 186 | data_loader = self.celeba_loader |
| 187 | elif self.dataset == 'RaFD': |
| 188 | data_loader = self.rafd_loader |
| 189 | |
| 190 | # Fetch fixed inputs for debugging. |
| 191 | data_iter = iter(data_loader) |
| 192 | x_fixed, c_org = next(data_iter) |
| 193 | x_fixed = x_fixed.to(self.device) |
| 194 | c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) |
| 195 | |
| 196 | # Learning rate cache for decaying. |
| 197 | g_lr = self.g_lr |
| 198 | d_lr = self.d_lr |
| 199 | |
| 200 | # Start training from scratch or resume training. |
| 201 | start_iters = 0 |
| 202 | if self.resume_iters: |
| 203 | start_iters = self.resume_iters |
| 204 | self.restore_model(self.resume_iters) |
| 205 | |
| 206 | # Start training. |
| 207 | print('Start training...') |
| 208 | start_time = time.time() |
| 209 | for i in range(start_iters, self.num_iters): |
| 210 | |
| 211 | # =================================================================================== # |
| 212 | # 1. Preprocess input data # |
| 213 | # =================================================================================== # |
| 214 | |
| 215 | # Fetch real images and labels. |
| 216 | try: |
| 217 | x_real, label_org = next(data_iter) |
| 218 | except: |
| 219 | data_iter = iter(data_loader) |
| 220 | x_real, label_org = next(data_iter) |
| 221 | |
| 222 | # Generate target domain labels randomly. |
| 223 | rand_idx = torch.randperm(label_org.size(0)) |
| 224 | label_trg = label_org[rand_idx] |
| 225 | |
| 226 | if self.dataset == 'CelebA': |
| 227 | c_org = label_org.clone() |
| 228 | c_trg = label_trg.clone() |
| 229 | elif self.dataset == 'RaFD': |
| 230 | c_org = self.label2onehot(label_org, self.c_dim) |
| 231 | c_trg = self.label2onehot(label_trg, self.c_dim) |
| 232 | |
| 233 | x_real = x_real.to(self.device) # Input images. |
| 234 | c_org = c_org.to(self.device) # Original domain labels. |
| 235 | c_trg = c_trg.to(self.device) # Target domain labels. |
| 236 | label_org = label_org.to(self.device) # Labels for computing classification loss. |
| 237 | label_trg = label_trg.to(self.device) # Labels for computing classification loss. |
| 238 | |
| 239 | # =================================================================================== # |
no test coverage detected