MCPcopy
hub / github.com/yunjey/stargan / train

Method train

solver.py:182–339  ·  view source on GitHub ↗

Train StarGAN within a single dataset.

(self)

Source from the content-addressed store, hash-verified

180 return F.cross_entropy(logit, target)
181
182 def train(self):
183 """Train StarGAN within a single dataset."""
184 # Set data loader.
185 if self.dataset == 'CelebA':
186 data_loader = self.celeba_loader
187 elif self.dataset == 'RaFD':
188 data_loader = self.rafd_loader
189
190 # Fetch fixed inputs for debugging.
191 data_iter = iter(data_loader)
192 x_fixed, c_org = next(data_iter)
193 x_fixed = x_fixed.to(self.device)
194 c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
195
196 # Learning rate cache for decaying.
197 g_lr = self.g_lr
198 d_lr = self.d_lr
199
200 # Start training from scratch or resume training.
201 start_iters = 0
202 if self.resume_iters:
203 start_iters = self.resume_iters
204 self.restore_model(self.resume_iters)
205
206 # Start training.
207 print('Start training...')
208 start_time = time.time()
209 for i in range(start_iters, self.num_iters):
210
211 # =================================================================================== #
212 # 1. Preprocess input data #
213 # =================================================================================== #
214
215 # Fetch real images and labels.
216 try:
217 x_real, label_org = next(data_iter)
218 except:
219 data_iter = iter(data_loader)
220 x_real, label_org = next(data_iter)
221
222 # Generate target domain labels randomly.
223 rand_idx = torch.randperm(label_org.size(0))
224 label_trg = label_org[rand_idx]
225
226 if self.dataset == 'CelebA':
227 c_org = label_org.clone()
228 c_trg = label_trg.clone()
229 elif self.dataset == 'RaFD':
230 c_org = self.label2onehot(label_org, self.c_dim)
231 c_trg = self.label2onehot(label_trg, self.c_dim)
232
233 x_real = x_real.to(self.device) # Input images.
234 c_org = c_org.to(self.device) # Original domain labels.
235 c_trg = c_trg.to(self.device) # Target domain labels.
236 label_org = label_org.to(self.device) # Labels for computing classification loss.
237 label_trg = label_trg.to(self.device) # Labels for computing classification loss.
238
239 # =================================================================================== #

Callers 1

main Function · 0.95

Calls 9

create_labels Method · 0.95
restore_model Method · 0.95
label2onehot Method · 0.95
classification_loss Method · 0.95
gradient_penalty Method · 0.95
reset_grad Method · 0.95
denorm Method · 0.95
update_lr Method · 0.95
scalar_summary Method · 0.80

Tested by

no test coverage detected