hub / github.com/MoonInTheRiver/DiffSinger / evaluate

Method evaluate

utils/pl_utils.py:1146–1219

Run evaluation code.

:param model: PT model
:param dataloaders: list of PT dataloaders
:param max_batches: Scalar
:param test: boolean
:return:

(self, model, dataloaders, max_batches, test=False)
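The docstring describes `max_batches` only as a scalar. In the source listing, the inner loop breaks once `batch_idx >= max_batches`, and the line just before the method sets `self.num_test_batches` to `float('inf')`. The following is a minimal sketch of that capping behavior, not the trainer's own code: `take_batches` is a hypothetical helper introduced for illustration, and whether `num_test_batches` is the value callers pass as `max_batches` is an assumption.

```python
def take_batches(dataloader, max_batches):
    """Yield (batch_idx, batch) pairs, stopping once batch_idx >= max_batches.

    Hypothetical helper: it mirrors only the None-skip and the
    `batch_idx >= max_batches` check from the listing.
    """
    for batch_idx, batch in enumerate(dataloader):
        if batch is None:
            # empty batches are skipped but still consume an index
            continue
        if batch_idx >= max_batches:
            break
        yield batch_idx, batch


# Any iterable works as a stand-in for a dataloader in this sketch.
batches = ["a", "b", None, "d"]

# No integer index is >= float('inf'), so the cap never triggers.
unlimited = list(take_batches(batches, float("inf")))

# A cap of 1 keeps only the first batch (the fast_dev_run case in the comments).
capped = list(take_batches(batches, 1))
```

Because the index comes from `enumerate`, a skipped `None` batch still counts toward the cap in this sketch.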

Source

        self.num_test_batches = float('inf')

    def evaluate(self, model, dataloaders, max_batches, test=False):
        """Run evaluation code.

        :param model: PT model
        :param dataloaders: list of PT dataloaders
        :param max_batches: Scalar
        :param test: boolean
        :return:
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        if test:
            self.get_model().test_start()
        # bookkeeping
        outputs = []

        # run training
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []
            for batch_idx, batch in enumerate(dataloader):

                if batch is None:  # pragma: no cover
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= max_batches:
                    break

                # -----------------
                # RUN EVALUATION STEP
                # -----------------
                output = self.evaluation_forward(model,
                                                 batch,
                                                 batch_idx,
                                                 dataloader_idx,
                                                 test)

                # track outputs for collation
                dl_outputs.append(output)

                # batch done
                if test:
                    self.test_progress_bar.update(1)
                else:
                    self.val_progress_bar.update(1)
            outputs.append(dl_outputs)

        # with a single dataloader don't pass an array
        if len(dataloaders) == 1:
            outputs = outputs[0]
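
The visible part of the method follows a common pattern: iterate over each dataloader, run one step per batch, collect the outputs per dataloader, and unwrap the outer list when there is only one dataloader. A framework-free sketch of that pattern might look like the following. `collect_outputs` and `step_fn` are hypothetical names standing in for the method and for `evaluation_forward`; the eval-mode, gradient, and progress-bar handling is deliberately left out, and returning `outputs` directly is an assumption, since the listing does not show what happens after the unwrapping step.

```python
def collect_outputs(dataloaders, step_fn, max_batches=float("inf")):
    """Run step_fn on every batch and gather results per dataloader.

    Hypothetical sketch of the loop shape only; not the trainer's method.
    """
    outputs = []
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []
        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue
            if batch_idx >= max_batches:
                break
            dl_outputs.append(step_fn(batch, batch_idx, dataloader_idx))
        outputs.append(dl_outputs)

    # With a single dataloader, hand back a flat list instead of a nested one.
    if len(dataloaders) == 1:
        outputs = outputs[0]
    return outputs


def double(batch, batch_idx, dataloader_idx):
    # Stand-in for an evaluation step; real steps would return metrics.
    return batch * 2


single = collect_outputs([[1, 2, 3]], double)     # one dataloader: flat list
multi = collect_outputs([[1, 2], [3]], double)    # two dataloaders: list of lists
```

Downstream code consuming these outputs has to handle both shapes, flat for one dataloader and nested for several.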

Callers 2

run_pretrain_routine · Method · 0.95
run_evaluation · Method · 0.95

Calls 8

get_model · Method · 0.95
evaluation_forward · Method · 0.95
update · Method · 0.80
validation_end · Method · 0.80
train · Method · 0.80
test_start · Method · 0.45
test_end · Method · 0.45

Tested by

no test coverage detected