(self, dy_model, metrics_list, batch_data, config)
| 64 | |
# construct train forward phase
def train_forward(self, dy_model, metrics_list, batch_data, config):
    """Run one training forward step and update the first metric.

    Args:
        dy_model: Model whose ``forward`` receives every feed produced by
            ``self.create_feeds`` except the last one.
        metrics_list: Sequence of metric objects; only ``metrics_list[0]``
            is used, through its ``compute`` and ``update`` methods.
        batch_data: Raw batch passed straight to ``self.create_feeds``.
        config: Accepted for interface compatibility; not read here.

    Returns:
        A ``(loss, metrics_list, print_dict)`` tuple, where ``loss`` comes
        from ``self.create_loss`` and ``print_dict`` is ``{"loss": loss}``.
    """
    # The last feed holds the labels; everything before it is model input.
    *model_inputs, raw_labels = self.create_feeds(batch_data)
    # Replace labels by the argmax index along the last axis, keeping that
    # axis (presumably one-hot -> class id; TODO confirm the feed format).
    label_ids = raw_labels.argmax(-1, keepdim=True)

    pred = dy_model.forward(*model_inputs)
    loss = self.create_loss(pred, label_ids)

    # Only the first metric is updated by this step.
    first_metric = metrics_list[0]
    first_metric.update(first_metric.compute(pred, label_ids))

    return loss, metrics_list, {"loss": loss}
| 77 | |
| 78 | def infer_forward(self, dy_model, metrics_list, batch_data, config): |
| 79 | inputs = self.create_feeds(batch_data) |
no test coverage detected