(self, model, batch, batch_idx, dataloader_idx, test=False)
| 1279 | logs=self.callback_metrics) |
| 1280 | |
| 1281 | def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False): |
| 1282 | # make dataloader_idx arg in validation_step optional |
| 1283 | args = [batch, batch_idx] |
| 1284 | |
| 1285 | if test and len(self.get_test_dataloaders()) > 1: |
| 1286 | args.append(dataloader_idx) |
| 1287 | |
| 1288 | elif not test and len(self.get_val_dataloaders()) > 1: |
| 1289 | args.append(dataloader_idx) |
| 1290 | |
| 1291 | # handle DP, DDP forward |
| 1292 | if self.use_ddp or self.use_dp: |
| 1293 | output = model(*args) |
| 1294 | return output |
| 1295 | |
| 1296 | # single GPU |
| 1297 | if self.single_gpu: |
| 1298 | # for single GPU put inputs on gpu manually |
| 1299 | root_gpu = 0 |
| 1300 | if isinstance(self.data_parallel_device_ids, list): |
| 1301 | root_gpu = self.data_parallel_device_ids[0] |
| 1302 | batch = self.transfer_batch_to_gpu(batch, root_gpu) |
| 1303 | args[0] = batch |
| 1304 | |
| 1305 | # CPU |
| 1306 | if test: |
| 1307 | output = model.test_step(*args) |
| 1308 | else: |
| 1309 | output = model.validation_step(*args) |
| 1310 | |
| 1311 | return output |
| 1312 | |
| 1313 | def train(self): |
| 1314 | model = self.get_model() |
no test coverage detected