MCPcopy Index your code
hub / github.com/POSTECH-CVLab/PyTorch-StudioGAN / evaluate

Function evaluate

src/evaluate.py:112–288  ·  view source on GitHub ↗
(local_rank, args, world_size, gpus_per_node)

Source from the content-addressed store, hash-verified

110
111
112def evaluate(local_rank, args, world_size, gpus_per_node):
113 # -----------------------------------------------------------------------------
114 # determine cuda, cudnn, and backends settings.
115 # -----------------------------------------------------------------------------
116 cudnn.benchmark, cudnn.deterministic = False, True
117
118 # -----------------------------------------------------------------------------
119 # initialize all processes and fix seed of each process
120 # -----------------------------------------------------------------------------
121 if args.distributed_data_parallel:
122 global_rank = args.current_node * (gpus_per_node) + local_rank
123 print("Use GPU: {global_rank} for training.".format(global_rank=global_rank))
124 misc.setup(global_rank, world_size, args.backend)
125 torch.cuda.set_device(local_rank)
126 else:
127 global_rank = local_rank
128
129 misc.fix_seed(args.seed + global_rank)
130
131 # -----------------------------------------------------------------------------
132 # load dset1 and dset1.
133 # -----------------------------------------------------------------------------
134 load_dset1 = ("fid" in args.eval_metrics and args.dset1_moments == None)*\
135 ("prdc" in args.eval_metrics and args.dset1_feats == None)
136 if load_dset1:
137 dset1 = Dataset_(data_dir=args.dset1)
138 if local_rank == 0:
139 print("Size of dset1: {dataset_size}".format(dataset_size=len(dset1)))
140
141 dset2 = Dataset_(data_dir=args.dset2)
142 if local_rank == 0:
143 print("Size of dset2: {dataset_size}".format(dataset_size=len(dset2)))
144
145 # -----------------------------------------------------------------------------
146 # define a distributed sampler for DDP evaluation.
147 # -----------------------------------------------------------------------------
148 if args.distributed_data_parallel:
149 batch_size = args.batch_size//world_size
150 if load_dset1:
151 dset1_sampler = DistributedSampler(dset1,
152 num_replicas=world_size,
153 rank=local_rank,
154 shuffle=False,
155 drop_last=False)
156
157 dset2_sampler = DistributedSampler(dset2,
158 num_replicas=world_size,
159 rank=local_rank,
160 shuffle=False,
161 drop_last=False)
162 else:
163 batch_size = args.batch_size
164 dset1_sampler, dset2_sampler = None, None
165
166 # -----------------------------------------------------------------------------
167 # define dataloaders for dset1 and dset2.
168 # -----------------------------------------------------------------------------
169 if load_dset1:

Callers 1

evaluate.py · File · 0.85

Calls 2

Dataset_ · Class · 0.70
update · Method · 0.45

Tested by

no test coverage detected