(local_rank, cfgs, gpus_per_node, run_name, hdf5_path)
| 37 | |
| 38 | |
| 39 | def load_worker(local_rank, cfgs, gpus_per_node, run_name, hdf5_path): |
| 40 | # ----------------------------------------------------------------------------- |
| 41 | # define default variables for loading ckpt or evaluating the trained GAN model. |
| 42 | # ----------------------------------------------------------------------------- |
| 43 | |
| 44 | load_train_dataset = cfgs.RUN.train + cfgs.RUN.GAN_train + cfgs.RUN.GAN_test |
| 45 | load_eval_dataset = len(cfgs.RUN.eval_metrics) + cfgs.RUN.save_real_images + cfgs.RUN.k_nearest_neighbor + \ |
| 46 | cfgs.RUN.frequency_analysis + cfgs.RUN.tsne_analysis + cfgs.RUN.intra_class_fid |
| 47 | train_sampler, eval_sampler = None, None |
| 48 | step, epoch, topk, best_step, best_fid, best_ckpt_path, lecam_emas, is_best = \ |
| 49 | 0, 0, cfgs.OPTIMIZATION.batch_size, 0, None, None, None, False |
| 50 | mu, sigma, real_feats, eval_model, num_rows, num_cols = None, None, None, None, 10, 8 |
| 51 | aa_p = cfgs.AUG.ada_initial_augment_p |
| 52 | if cfgs.AUG.ada_initial_augment_p != "N/A": |
| 53 | aa_p = cfgs.AUG.ada_initial_augment_p |
| 54 | else: |
| 55 | aa_p = cfgs.AUG.apa_initial_augment_p |
| 56 | |
| 57 | loss_list_dict = {"gen_loss": [], "dis_loss": [], "cls_loss": []} |
| 58 | num_eval = {} |
| 59 | metric_dict_during_train = {} |
| 60 | if "none" in cfgs.RUN.eval_metrics: |
| 61 | cfgs.RUN.eval_metrics = [] |
| 62 | if "is" in cfgs.RUN.eval_metrics: |
| 63 | metric_dict_during_train.update({"IS": [], "Top1_acc": [], "Top5_acc": []}) |
| 64 | if "fid" in cfgs.RUN.eval_metrics: |
| 65 | metric_dict_during_train.update({"FID": []}) |
| 66 | if "prdc" in cfgs.RUN.eval_metrics: |
| 67 | metric_dict_during_train.update({"Improved_Precision": [], "Improved_Recall": [], "Density":[], "Coverage": []}) |
| 68 | |
| 69 | # ----------------------------------------------------------------------------- |
| 70 | # determine cuda, cudnn, and backends settings. |
| 71 | # ----------------------------------------------------------------------------- |
| 72 | if cfgs.RUN.fix_seed: |
| 73 | cudnn.benchmark, cudnn.deterministic = False, True |
| 74 | else: |
| 75 | cudnn.benchmark, cudnn.deterministic = True, False |
| 76 | |
| 77 | if cfgs.MODEL.backbone in ["stylegan2", "stylegan3"]: |
| 78 | # Improves training speed |
| 79 | conv2d_gradfix.enabled = True |
| 80 | # Avoids errors with the augmentation pipe |
| 81 | grid_sample_gradfix.enabled = True |
| 82 | if cfgs.RUN.mixed_precision: |
| 83 | # Allow PyTorch to internally use tf32 for matmul |
| 84 | torch.backends.cuda.matmul.allow_tf32 = False |
| 85 | # Allow PyTorch to internally use tf32 for convolutions |
| 86 | torch.backends.cudnn.allow_tf32 = False |
| 87 | |
| 88 | # ----------------------------------------------------------------------------- |
| 89 | # initialize all processes and fix seed of each process |
| 90 | # ----------------------------------------------------------------------------- |
| 91 | if cfgs.RUN.distributed_data_parallel: |
| 92 | global_rank = cfgs.RUN.current_node * (gpus_per_node) + local_rank |
| 93 | print("Use GPU: {global_rank} for training.".format(global_rank=global_rank)) |
| 94 | misc.setup(global_rank, cfgs.OPTIMIZATION.world_size, cfgs.RUN.backend) |
| 95 | torch.cuda.set_device(local_rank) |
| 96 | else: |
nothing calls this directly
no test coverage detected