MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / evaluate

Method evaluate

utils/commons/trainer.py:256–304  ·  view source on GitHub ↗
(self, task, test=False, tqdm_desc='Valid', max_batches=None)

Source from the content-addressed store, hash-verified

254 self.save_checkpoint(epoch=self.current_epoch, logs=eval_results)
255
def evaluate(self, task, test=False, tqdm_desc='Valid', max_batches=None):
    """Run one validation (or test) pass over the task's dataloader.

    Puts ``task`` into eval mode with autograd disabled, then iterates over
    ``task_ref.val_dataloader()`` (or ``task_ref.test_dataloader()`` when
    ``test`` is true). It collects the step output of every non-``None`` batch
    and hands that list to ``validation_end`` / ``test_end``.

    Args:
        task: the module being evaluated. ``zero_grad``, ``eval`` and ``train``
            are called on it. When ``self.use_ddp`` is set, it is invoked
            directly as ``task(batch, batch_idx)``.
        test: if true, use the ``test_*`` hooks and ``test_dataloader``;
            otherwise use the ``validation_*`` hooks and ``val_dataloader``.
        tqdm_desc: label shown on the progress bar. The bar is disabled when
            ``self.root_gpu > 0``.
        max_batches: stop after this many batches. ``None`` or ``-1`` means
            iterate the whole dataloader.

    Returns:
        Whatever ``test_end`` / ``validation_end`` returns, or ``None`` when
        ``test_start()`` returns ``'EXIT'``.

    On every exit path the task is put back into train mode and the autograd
    state that was active on entry is restored. This covers a normal return,
    the ``'EXIT'`` early return, and an exception raised by a hook.
    """
    if max_batches == -1:
        max_batches = None

    # Remember the caller's autograd state so it is restored rather than forced on.
    prev_grad_enabled = torch.is_grad_enabled()
    # enable eval mode
    task.zero_grad()
    task.eval()
    torch.set_grad_enabled(False)
    try:
        task_ref = self.get_task_ref()
        if test:
            ret = task_ref.test_start()
            if ret == 'EXIT':
                return None
        else:
            task_ref.validation_start()

        outputs = []
        dataloader = task_ref.test_dataloader() if test else task_ref.val_dataloader()
        pbar = tqdm.tqdm(dataloader, desc=tqdm_desc, total=max_batches, dynamic_ncols=True, unit='step',
                         disable=self.root_gpu > 0)
        try:
            for batch_idx, batch in enumerate(pbar):
                if batch is None:  # pragma: no cover
                    continue
                # stop short once max_batches batches have been consumed
                if max_batches is not None and batch_idx >= max_batches:
                    break

                if self.on_gpu:
                    batch = move_to_cuda(batch, self.root_gpu)
                args = [batch, batch_idx]
                if self.use_ddp:
                    # Under DDP the wrapped module is called directly; presumably its
                    # forward dispatches to the matching step method -- TODO confirm.
                    output = task(*args)
                elif test:
                    output = task_ref.test_step(*args)
                else:
                    output = task_ref.validation_step(*args)
                # track outputs for collation
                outputs.append(output)
        finally:
            # Close explicitly: an early `break` (or an exception) would otherwise leave the bar open.
            pbar.close()

        # give the task a chance to aggregate the collected outputs
        if test:
            return task_ref.test_end(outputs)
        return task_ref.validation_end(outputs)
    finally:
        # enable train mode again and restore autograd to its state on entry
        task.train()
        torch.set_grad_enabled(prev_grad_enabled)
305
306 ####################
307 # train

Callers 2

run_evaluation (method) · 0.95
train (method) · 0.95

Calls 12

get_task_ref (method) · 0.95
move_to_cuda (function) · 0.90
append (method) · 0.80
test_start (method) · 0.45
validation_start (method) · 0.45
test_dataloader (method) · 0.45
val_dataloader (method) · 0.45
test_step (method) · 0.45
validation_step (method) · 0.45
test_end (method) · 0.45
validation_end (method) · 0.45
train (method) · 0.45

Tested by

no test coverage detected