(self, dy_model, metrics_list, batch_data, config)
| 105 | |
| 106 | # construct train forward phase |
def train_forward(self, dy_model, metrics_list, batch_data, config):
    """Run one training forward pass, compute the loss and update metrics.

    Builds the model inputs from ``batch_data`` via ``self.create_feeds``,
    runs ``dy_model.forward``, computes the penalised loss via
    ``self.create_loss`` and updates the first two entries of
    ``metrics_list`` with this batch's predictions.

    Args:
        dy_model: The dygraph model. Its ``forward`` must return a 3-tuple
            ``(output, l0_penalty, l2_penalty)``.
        metrics_list: Sequence of metric objects. ``metrics_list[0]`` is
            updated via ``update(preds=..., labels=...)``;
            ``metrics_list[1]`` is fed through ``compute(...)`` and then
            ``update(...)``.
        batch_data: Raw batch passed unchanged to ``self.create_feeds``.
        config: Object exposing ``get(key, default)``. The keys
            ``hyper_parameters.l0_weight`` and ``hyper_parameters.l2_weight``
            (each defaulting to 0.001) weight the two penalty terms.

    Returns:
        Tuple ``(loss, metrics_list, print_dict)`` where ``print_dict`` is
        ``{'loss': loss}``.
    """
    edges, node_feat, edge_feat, segment_ids, labels = self.create_feeds(
        batch_data, config)
    # predict; the trailing positional True is presumably a training-mode
    # flag for forward -- TODO confirm against the model definition.
    output, l0_penalty, l2_penalty = dy_model.forward(
        edges, node_feat, edge_feat, segment_ids, True)
    # get loss: each penalty weight is read from its own config key.
    l0_weight = config.get("hyper_parameters.l0_weight", 0.001)
    l2_weight = config.get("hyper_parameters.l2_weight", 0.001)
    loss = self.create_loss(output, labels, l0_penalty, l2_penalty,
                            l0_weight, l2_weight)
    # update metrics
    predictions = np.vstack(output)
    # Keep only column 1 of the stacked labels, as a single (N, 1) column.
    # NOTE(review): presumably the positive-class column of one-hot labels;
    # verify against create_feeds.
    label_col = np.vstack(labels)[:, 1].reshape((-1, 1))
    metrics_list[0].update(preds=predictions, labels=label_col)
    correct = metrics_list[1].compute(
        paddle.to_tensor(predictions), paddle.to_tensor(label_col))
    metrics_list[1].update(correct)
    # print dict
    print_dict = {'loss': loss}
    return loss, metrics_list, print_dict
| 129 | |
| 130 | # construct infer forward phase |
| 131 | def infer_forward(self, dy_model, metrics_list, batch_data, config): |
nothing calls this directly
no test coverage detected