MCPcopy
hub / github.com/PaddlePaddle/PaddleRec / train_forward

Method train_forward

models/rank/sign/dygraph_model.py:107–128  ·  view source on GitHub ↗
(self, dy_model, metrics_list, batch_data, config)

Source from the content-addressed store, hash-verified

105
106 # construct train forward phase
107 def train_forward(self, dy_model, metrics_list, batch_data, config):
108 edges, node_feat, edge_feat, segment_ids, labels = self.create_feeds(
109 batch_data, config)
110 # predict
111 output, l0_penaty, l2_penaty = dy_model.forward(
112 edges, node_feat, edge_feat, segment_ids, True)
113 # get loss
114 l0_weight = config.get("hyper_parameters.l0_weight", 0.001)
115 l2_weight = config.get("hyper_parameters.l0_weight", 0.001)
116 loss = self.create_loss(output, labels, l0_penaty, l2_penaty,
117 l0_weight, l2_weight)
118 # update metrics
119 predictions = np.vstack(output)
120 labels = np.vstack(labels)
121 labels = labels[:, 1].reshape((-1, 1))
122 metrics_list[0].update(preds=predictions, labels=labels)
123 correct = metrics_list[1].compute(
124 paddle.to_tensor(predictions), paddle.to_tensor(labels))
125 metrics_list[1].update(correct)
126 # print dict
127 print_dict = {'loss': loss}
128 return loss, metrics_list, print_dict
129
130 # construct infer forward phase
131 def infer_forward(self, dy_model, metrics_list, batch_data, config):

Callers

nothing calls this directly

Calls 5

create_feedsMethod · 0.95
create_lossMethod · 0.95
updateMethod · 0.80
computeMethod · 0.80
forwardMethod · 0.45

Tested by

no test coverage detected