Run evaluation code. :param model: PT model :param dataloaders: list of PT dataloaders :param max_batches: Scalar :param test: boolean :return:
(self, model, dataloaders, max_batches, test=False)
| 1144 | self.num_test_batches = float('inf') |
| 1145 | |
| 1146 | def evaluate(self, model, dataloaders, max_batches, test=False): |
| 1147 | """Run evaluation code. |
| 1148 | |
| 1149 | :param model: PT model |
| 1150 | :param dataloaders: list of PT dataloaders |
| 1151 | :param max_batches: Scalar |
| 1152 | :param test: boolean |
| 1153 | :return: |
| 1154 | """ |
| 1155 | # enable eval mode |
| 1156 | model.zero_grad() |
| 1157 | model.eval() |
| 1158 | |
| 1159 | # copy properties for forward overrides |
| 1160 | self.copy_trainer_model_properties(model) |
| 1161 | |
| 1162 | # disable gradients to save memory |
| 1163 | torch.set_grad_enabled(False) |
| 1164 | |
| 1165 | if test: |
| 1166 | self.get_model().test_start() |
| 1167 | # bookkeeping |
| 1168 | outputs = [] |
| 1169 | |
| 1170 | # run evaluation |
| 1171 | for dataloader_idx, dataloader in enumerate(dataloaders): |
| 1172 | dl_outputs = [] |
| 1173 | for batch_idx, batch in enumerate(dataloader): |
| 1174 | |
| 1175 | if batch is None: # pragma: no cover |
| 1176 | continue |
| 1177 | |
| 1178 | # stop short when on fast_dev_run (sets max_batches=1) |
| 1179 | if batch_idx >= max_batches: |
| 1180 | break |
| 1181 | |
| 1182 | # ----------------- |
| 1183 | # RUN EVALUATION STEP |
| 1184 | # ----------------- |
| 1185 | output = self.evaluation_forward(model, |
| 1186 | batch, |
| 1187 | batch_idx, |
| 1188 | dataloader_idx, |
| 1189 | test) |
| 1190 | |
| 1191 | # track outputs for collation |
| 1192 | dl_outputs.append(output) |
| 1193 | |
| 1194 | # batch done |
| 1195 | if test: |
| 1196 | self.test_progress_bar.update(1) |
| 1197 | else: |
| 1198 | self.val_progress_bar.update(1) |
| 1199 | outputs.append(dl_outputs) |
| 1200 | |
| 1201 | # with a single dataloader don't pass an array |
| 1202 | if len(dataloaders) == 1: |
| 1203 | outputs = outputs[0] |
no test coverage detected