(
run_dir = '.', # Output directory.
training_set_kwargs = {}, # Options for training set.
data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
G_kwargs = {}, # Options for generator network.
D_kwargs = {}, # Options for discriminator network.
G_opt_kwargs = {}, # Options for generator optimizer.
D_opt_kwargs = {}, # Options for discriminator optimizer.
augment_kwargs = None, # Options for augmentation pipeline. None = disable.
loss_kwargs = {}, # Options for loss function.
metrics = [], # Metrics to evaluate during training.
random_seed = 0, # Global random seed.
num_gpus = 1, # Number of GPUs participating in the training.
rank = 0, # Rank of the current process in [0, num_gpus[.
batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
batch_gpu = 4, # Number of samples processed at a time by one GPU.
ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights.
ema_rampup = 0.05, # EMA ramp-up coefficient. None = no rampup.
G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization.
D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization.
augment_p = 0, # Initial value of augmentation probability.
ada_target = None, # ADA target value. None = fixed p.
ada_interval = 4, # How often to perform ADA adjustment?
ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
total_kimg = 25000, # Total length of the training, measured in thousands of real images.
kimg_per_tick = 4, # Progress snapshot interval.
image_snapshot_ticks = 50, # How often to save image snapshots? None = disable.
network_snapshot_ticks = 50, # How often to save network snapshots? None = disable.
resume_pkl = None, # Network pickle to resume training from.
resume_kimg = 0, # First kimg to report when resuming training.
cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks.
progress_fn = None, # Callback function for updating training progress. Called for all ranks.
)
| 89 | #---------------------------------------------------------------------------- |
| 90 | |
| 91 | def training_loop( |
| 92 | run_dir = '.', # Output directory. |
| 93 | training_set_kwargs = {}, # Options for training set. |
| 94 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. |
| 95 | G_kwargs = {}, # Options for generator network. |
| 96 | D_kwargs = {}, # Options for discriminator network. |
| 97 | G_opt_kwargs = {}, # Options for generator optimizer. |
| 98 | D_opt_kwargs = {}, # Options for discriminator optimizer. |
| 99 | augment_kwargs = None, # Options for augmentation pipeline. None = disable. |
| 100 | loss_kwargs = {}, # Options for loss function. |
| 101 | metrics = [], # Metrics to evaluate during training. |
| 102 | random_seed = 0, # Global random seed. |
| 103 | num_gpus = 1, # Number of GPUs participating in the training. |
| 104 | rank = 0, # Rank of the current process in [0, num_gpus[. |
| 105 | batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. |
| 106 | batch_gpu = 4, # Number of samples processed at a time by one GPU. |
| 107 | ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights. |
| 108 | ema_rampup = 0.05, # EMA ramp-up coefficient. None = no rampup. |
| 109 | G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization. |
| 110 | D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization. |
| 111 | augment_p = 0, # Initial value of augmentation probability. |
| 112 | ada_target = None, # ADA target value. None = fixed p. |
| 113 | ada_interval = 4, # How often to perform ADA adjustment? |
| 114 | ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. |
| 115 | total_kimg = 25000, # Total length of the training, measured in thousands of real images. |
| 116 | kimg_per_tick = 4, # Progress snapshot interval. |
| 117 | image_snapshot_ticks = 50, # How often to save image snapshots? None = disable. |
| 118 | network_snapshot_ticks = 50, # How often to save network snapshots? None = disable. |
| 119 | resume_pkl = None, # Network pickle to resume training from. |
| 120 | resume_kimg = 0, # First kimg to report when resuming training. |
| 121 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? |
| 122 | abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks. |
| 123 | progress_fn = None, # Callback function for updating training progress. Called for all ranks. |
| 124 | ): |
| 125 | # Initialize. |
| 126 | start_time = time.time() |
| 127 | device = torch.device('cuda', rank) |
| 128 | np.random.seed(random_seed * num_gpus + rank) |
| 129 | torch.manual_seed(random_seed * num_gpus + rank) |
| 130 | torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. |
| 131 | torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy. |
| 132 | torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy. |
| 133 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False # Improves numerical accuracy. |
| 134 | conv2d_gradfix.enabled = True # Improves training speed. # TODO: ENABLE |
| 135 | grid_sample_gradfix.enabled = False # Avoids errors with the augmentation pipe. |
| 136 | |
| 137 | # Load training set. |
| 138 | if rank == 0: |
| 139 | print('Loading training set...') |
| 140 | training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset |
| 141 | training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) |
| 142 | training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)) |
| 143 | if rank == 0: |
| 144 | print() |
| 145 | print('Num images: ', len(training_set)) |
| 146 | print('Image shape:', training_set.image_shape) |
| 147 | print('Label shape:', training_set.label_shape) |
| 148 | print() |
Nothing calls this function directly.
No test coverage was detected.