(self, task, test=False, tqdm_desc='Valid', max_batches=None)
| 254 | self.save_checkpoint(epoch=self.current_epoch, logs=eval_results) |
| 255 | |
def evaluate(self, task, test=False, tqdm_desc='Valid', max_batches=None):
    """Run one pass over the validation or test dataloader without gradients.

    Puts ``task`` into eval mode with autograd disabled, and feeds every
    batch to the step hook:
    - ``task(batch, batch_idx)`` under DDP;
    - otherwise ``test_step`` or ``validation_step`` on the task returned by
      ``self.get_task_ref()``.

    The collected outputs are then passed to ``test_end`` or
    ``validation_end``. However this method exits, ``task.train()`` is called
    and the caller's previous autograd setting is restored. This covers
    normal completion, the early ``'EXIT'`` return from ``test_start()``, and
    an exception raised by a hook.

    Args:
        task: The module being evaluated. It is switched between eval and
            train mode, and called directly when ``self.use_ddp`` is set.
        test: If True, use the test hooks and ``test_dataloader()``.
            Otherwise use the validation hooks and ``val_dataloader()``.
        tqdm_desc: Label shown on the progress bar.
        max_batches: Stop once ``batch_idx`` reaches this value. ``None``
            or ``-1`` means no limit.

    Returns:
        The value returned by ``test_end``/``validation_end``. Returns
        ``None`` when ``test_start()`` returns ``'EXIT'``.
    """
    if max_batches == -1:
        max_batches = None
    # Capture the caller's autograd state so the ``finally`` below restores it
    # exactly, rather than unconditionally switching gradients back on.
    grad_was_enabled = torch.is_grad_enabled()
    task.zero_grad()
    task.eval()
    torch.set_grad_enabled(False)
    try:
        task_ref = self.get_task_ref()
        if test:
            ret = task_ref.test_start()
            if ret == 'EXIT':
                return
        else:
            task_ref.validation_start()
        outputs = []
        dataloader = task_ref.test_dataloader() if test else task_ref.val_dataloader()
        # The bar is disabled when root_gpu > 0, so only one process draws it.
        # Using ``with`` closes the bar even if the loop breaks early or raises.
        with tqdm.tqdm(dataloader, desc=tqdm_desc, total=max_batches, dynamic_ncols=True,
                       unit='step', disable=self.root_gpu > 0) as pbar:
            for batch_idx, batch in enumerate(pbar):
                # Stop short once the batch limit is reached (fast_dev_run sets max_batches=1).
                if max_batches is not None and batch_idx >= max_batches:
                    break
                if batch is None:  # pragma: no cover
                    continue
                if self.on_gpu:
                    batch = move_to_cuda(batch, self.root_gpu)
                args = [batch, batch_idx]
                if self.use_ddp:
                    # NOTE(review): presumably the DDP wrapper's forward routes to
                    # test_step/validation_step -- confirm against the task class.
                    output = task(*args)
                elif test:
                    output = task_ref.test_step(*args)
                else:
                    output = task_ref.validation_step(*args)
                # Keep every step output for collation in the end hook.
                outputs.append(output)
        # Give the task a chance to aggregate the collected step outputs.
        if test:
            return task_ref.test_end(outputs)
        return task_ref.validation_end(outputs)
    finally:
        # Always return to training mode and restore the caller's autograd setting.
        task.train()
        torch.set_grad_enabled(grad_was_enabled)
| 305 | |
| 306 | #################### |
| 307 | # train |
no test coverage detected