(local_rank, args, world_size, gpus_per_node)
| 110 | |
| 111 | |
| 112 | def evaluate(local_rank, args, world_size, gpus_per_node): |
| 113 | # ----------------------------------------------------------------------------- |
| 114 | # determine cuda, cudnn, and backends settings. |
| 115 | # ----------------------------------------------------------------------------- |
| 116 | cudnn.benchmark, cudnn.deterministic = False, True |
| 117 | |
| 118 | # ----------------------------------------------------------------------------- |
| 119 | # initialize all processes and fix seed of each process |
| 120 | # ----------------------------------------------------------------------------- |
| 121 | if args.distributed_data_parallel: |
| 122 | global_rank = args.current_node * (gpus_per_node) + local_rank |
| 123 | print("Use GPU: {global_rank} for training.".format(global_rank=global_rank)) |
| 124 | misc.setup(global_rank, world_size, args.backend) |
| 125 | torch.cuda.set_device(local_rank) |
| 126 | else: |
| 127 | global_rank = local_rank |
| 128 | |
| 129 | misc.fix_seed(args.seed + global_rank) |
| 130 | |
| 131 | # ----------------------------------------------------------------------------- |
| 132 | # load dset1 and dset1. |
| 133 | # ----------------------------------------------------------------------------- |
| 134 | load_dset1 = ("fid" in args.eval_metrics and args.dset1_moments == None)*\ |
| 135 | ("prdc" in args.eval_metrics and args.dset1_feats == None) |
| 136 | if load_dset1: |
| 137 | dset1 = Dataset_(data_dir=args.dset1) |
| 138 | if local_rank == 0: |
| 139 | print("Size of dset1: {dataset_size}".format(dataset_size=len(dset1))) |
| 140 | |
| 141 | dset2 = Dataset_(data_dir=args.dset2) |
| 142 | if local_rank == 0: |
| 143 | print("Size of dset2: {dataset_size}".format(dataset_size=len(dset2))) |
| 144 | |
| 145 | # ----------------------------------------------------------------------------- |
| 146 | # define a distributed sampler for DDP evaluation. |
| 147 | # ----------------------------------------------------------------------------- |
| 148 | if args.distributed_data_parallel: |
| 149 | batch_size = args.batch_size//world_size |
| 150 | if load_dset1: |
| 151 | dset1_sampler = DistributedSampler(dset1, |
| 152 | num_replicas=world_size, |
| 153 | rank=local_rank, |
| 154 | shuffle=False, |
| 155 | drop_last=False) |
| 156 | |
| 157 | dset2_sampler = DistributedSampler(dset2, |
| 158 | num_replicas=world_size, |
| 159 | rank=local_rank, |
| 160 | shuffle=False, |
| 161 | drop_last=False) |
| 162 | else: |
| 163 | batch_size = args.batch_size |
| 164 | dset1_sampler, dset2_sampler = None, None |
| 165 | |
| 166 | # ----------------------------------------------------------------------------- |
| 167 | # define dataloaders for dset1 and dset2. |
| 168 | # ----------------------------------------------------------------------------- |
| 169 | if load_dset1: |
no test coverage detected