MCPcopy
hub / github.com/SizheAn/PanoHead / training_loop

Function training_loop

training/training_loop.py:91–493  ·  view source on GitHub ↗
(
    run_dir                 = '.',      # Output directory.
    training_set_kwargs     = {},       # Options for training set.
    data_loader_kwargs      = {},       # Options for torch.utils.data.DataLoader.
    G_kwargs                = {},       # Options for generator network.
    D_kwargs                = {},       # Options for discriminator network.
    G_opt_kwargs            = {},       # Options for generator optimizer.
    D_opt_kwargs            = {},       # Options for discriminator optimizer.
    augment_kwargs          = None,     # Options for augmentation pipeline. None = disable.
    loss_kwargs             = {},       # Options for loss function.
    metrics                 = [],       # Metrics to evaluate during training.
    random_seed             = 0,        # Global random seed.
    num_gpus                = 1,        # Number of GPUs participating in the training.
    rank                    = 0,        # Rank of the current process in [0, num_gpus[.
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU.
    ema_kimg                = 10,       # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup              = 0.05,     # EMA ramp-up coefficient. None = no rampup.
    G_reg_interval          = None,     # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval          = 16,       # How often to perform regularization for D? None = disable lazy regularization.
    augment_p               = 0,        # Initial value of augmentation probability.
    ada_target              = None,     # ADA target value. None = fixed p.
    ada_interval            = 4,        # How often to perform ADA adjustment?
    ada_kimg                = 500,      # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    kimg_per_tick           = 4,        # Progress snapshot interval.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = disable.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
    resume_pkl              = None,     # Network pickle to resume training from.
    resume_kimg             = 0,        # First kimg to report when resuming training.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
)

Source from the content-addressed store, hash-verified

89#----------------------------------------------------------------------------
90
91def training_loop(
92 run_dir = '.', # Output directory.
93 training_set_kwargs = {}, # Options for training set.
94 data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
95 G_kwargs = {}, # Options for generator network.
96 D_kwargs = {}, # Options for discriminator network.
97 G_opt_kwargs = {}, # Options for generator optimizer.
98 D_opt_kwargs = {}, # Options for discriminator optimizer.
99 augment_kwargs = None, # Options for augmentation pipeline. None = disable.
100 loss_kwargs = {}, # Options for loss function.
101 metrics = [], # Metrics to evaluate during training.
102 random_seed = 0, # Global random seed.
103 num_gpus = 1, # Number of GPUs participating in the training.
104 rank = 0, # Rank of the current process in [0, num_gpus[.
105 batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
106 batch_gpu = 4, # Number of samples processed at a time by one GPU.
107 ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights.
108 ema_rampup = 0.05, # EMA ramp-up coefficient. None = no rampup.
109 G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization.
110 D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization.
111 augment_p = 0, # Initial value of augmentation probability.
112 ada_target = None, # ADA target value. None = fixed p.
113 ada_interval = 4, # How often to perform ADA adjustment?
114 ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
115 total_kimg = 25000, # Total length of the training, measured in thousands of real images.
116 kimg_per_tick = 4, # Progress snapshot interval.
117 image_snapshot_ticks = 50, # How often to save image snapshots? None = disable.
118 network_snapshot_ticks = 50, # How often to save network snapshots? None = disable.
119 resume_pkl = None, # Network pickle to resume training from.
120 resume_kimg = 0, # First kimg to report when resuming training.
121 cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
122 abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks.
123 progress_fn = None, # Callback function for updating training progress. Called for all ranks.
124):
125 # Initialize.
126 start_time = time.time()
127 device = torch.device('cuda', rank)
128 np.random.seed(random_seed * num_gpus + rank)
129 torch.manual_seed(random_seed * num_gpus + rank)
130 torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed.
131 torch.backends.cuda.matmul.allow_tf32 = False # Improves numerical accuracy.
132 torch.backends.cudnn.allow_tf32 = False # Improves numerical accuracy.
133 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False # Improves numerical accuracy.
134 conv2d_gradfix.enabled = True # Improves training speed. # TODO: ENABLE
135 grid_sample_gradfix.enabled = False # Avoids errors with the augmentation pipe.
136
137 # Load training set.
138 if rank == 0:
139 print('Loading training set...')
140 training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
141 training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
142 training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
143 if rank == 0:
144 print()
145 print('Num images: ', len(training_set))
146 print('Image shape:', training_set.image_shape)
147 print('Label shape:', training_set.label_shape)
148 print()

Callers

nothing calls this directly

Calls 9

update (method) · 0.95
as_dict (method) · 0.95
save_image_grid (function) · 0.85
get_label_std (method) · 0.80
get_label (method) · 0.80
write (method) · 0.80
flush (method) · 0.80
accumulate_gradients (method) · 0.45

Tested by

no test coverage detected