MCP · copy · Index your code
hub / github.com/POSTECH-CVLab/PyTorch-StudioGAN / load_worker

Function load_worker

src/loader.py:39–510  ·  view source on GitHub ↗
(local_rank, cfgs, gpus_per_node, run_name, hdf5_path)

Source from the content-addressed store, hash-verified

37
38
39def load_worker(local_rank, cfgs, gpus_per_node, run_name, hdf5_path):
40 # -----------------------------------------------------------------------------
41 # define default variables for loading ckpt or evaluating the trained GAN model.
42 # -----------------------------------------------------------------------------
43
44 load_train_dataset = cfgs.RUN.train + cfgs.RUN.GAN_train + cfgs.RUN.GAN_test
45 load_eval_dataset = len(cfgs.RUN.eval_metrics) + cfgs.RUN.save_real_images + cfgs.RUN.k_nearest_neighbor + \
46 cfgs.RUN.frequency_analysis + cfgs.RUN.tsne_analysis + cfgs.RUN.intra_class_fid
47 train_sampler, eval_sampler = None, None
48 step, epoch, topk, best_step, best_fid, best_ckpt_path, lecam_emas, is_best = \
49 0, 0, cfgs.OPTIMIZATION.batch_size, 0, None, None, None, False
50 mu, sigma, real_feats, eval_model, num_rows, num_cols = None, None, None, None, 10, 8
51 aa_p = cfgs.AUG.ada_initial_augment_p
52 if cfgs.AUG.ada_initial_augment_p != "N/A":
53 aa_p = cfgs.AUG.ada_initial_augment_p
54 else:
55 aa_p = cfgs.AUG.apa_initial_augment_p
56
57 loss_list_dict = {"gen_loss": [], "dis_loss": [], "cls_loss": []}
58 num_eval = {}
59 metric_dict_during_train = {}
60 if "none" in cfgs.RUN.eval_metrics:
61 cfgs.RUN.eval_metrics = []
62 if "is" in cfgs.RUN.eval_metrics:
63 metric_dict_during_train.update({"IS": [], "Top1_acc": [], "Top5_acc": []})
64 if "fid" in cfgs.RUN.eval_metrics:
65 metric_dict_during_train.update({"FID": []})
66 if "prdc" in cfgs.RUN.eval_metrics:
67 metric_dict_during_train.update({"Improved_Precision": [], "Improved_Recall": [], "Density":[], "Coverage": []})
68
69 # -----------------------------------------------------------------------------
70 # determine cuda, cudnn, and backends settings.
71 # -----------------------------------------------------------------------------
72 if cfgs.RUN.fix_seed:
73 cudnn.benchmark, cudnn.deterministic = False, True
74 else:
75 cudnn.benchmark, cudnn.deterministic = True, False
76
77 if cfgs.MODEL.backbone in ["stylegan2", "stylegan3"]:
78 # Improves training speed
79 conv2d_gradfix.enabled = True
80 # Avoids errors with the augmentation pipe
81 grid_sample_gradfix.enabled = True
82 if cfgs.RUN.mixed_precision:
83 # Allow PyTorch to internally use tf32 for matmul
84 torch.backends.cuda.matmul.allow_tf32 = False
85 # Allow PyTorch to internally use tf32 for convolutions
86 torch.backends.cudnn.allow_tf32 = False
87
88 # -----------------------------------------------------------------------------
89 # initialize all processes and fix seed of each process
90 # -----------------------------------------------------------------------------
91 if cfgs.RUN.distributed_data_parallel:
92 global_rank = cfgs.RUN.current_node * (gpus_per_node) + local_rank
93 print("Use GPU: {global_rank} for training.".format(global_rank=global_rank))
94 misc.setup(global_rank, cfgs.OPTIMIZATION.world_size, cfgs.RUN.backend)
95 torch.cuda.set_device(local_rank)
96 else:

Callers

nothing calls this directly

Calls 15

prepare_train_iter · Method · 0.95
train_discriminator · Method · 0.95
train_generator · Method · 0.95
log_train_statistics · Method · 0.95
visualize_fake_images · Method · 0.95
evaluate · Method · 0.95
save · Method · 0.95
save_real_images · Method · 0.95
save_fake_images · Method · 0.95

Tested by

no test coverage detected