Repository: github.com/PaddlePaddle/PaddleRec

Method train_forward

models/match/kim/dygraph_model.py, lines 66–76
(self, dy_model, metrics_list, batch_data, config)


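```python
# construct train forward phase
def train_forward(self, dy_model, metrics_list, batch_data, config):
    # create_feeds returns the model inputs followed by the label tensor (last element).
    *inputs, labels = self.create_feeds(batch_data)
    # Collapse each per-class label row to its class index, giving shape [batch, 1].
    labels = labels.argmax(-1, keepdim=True)

    # Calls forward directly rather than dy_model(*inputs), so Layer hooks are not run.
    prediction = dy_model.forward(*inputs)
    loss = self.create_loss(prediction, labels)
    # Values reported in the training log.
    print_dict = {"loss": loss}
    # update metrics: compute() scores this batch, update() accumulates it in the metric.
    correct = metrics_list[0].compute(prediction, labels)
    metrics_list[0].update(correct)
    # config is not used by this method.
    return loss, metrics_list, print_dict

def infer_forward(self, dy_model, metrics_list, batch_data, config):
    inputs = self.create_feeds(batch_data)
    # ... (the rest of infer_forward lies outside the lines shown here)
```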
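To make the label handling and metric bookkeeping concrete, here is a minimal pure-Python sketch with no Paddle dependency. It assumes the labels arrive as one-hot rows over classes, which is what `argmax(-1, keepdim=True)` implies but the excerpt does not show. `argmax_keepdim` and `TopOneAccuracy` are hypothetical names; the metric only imitates the `compute`/`update` split that `train_forward` relies on (as in `paddle.metric.Accuracy`), restricted to top-1 accuracy, and is not PaddleRec's actual metric.

```python
def argmax_keepdim(rows):
    """Index of the largest entry in each row, kept as a [batch, 1] column.

    Mirrors labels.argmax(-1, keepdim=True) on a 2-D tensor; ties go to the
    first maximal index.
    """
    return [[max(range(len(row)), key=row.__getitem__)] for row in rows]


class TopOneAccuracy:
    """Hypothetical stand-in for a stateful metric with a compute/update split."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def compute(self, prediction, labels):
        # Per-sample 0/1: does the highest-scoring class match the label index?
        pred_idx = argmax_keepdim(prediction)
        return [int(p[0] == y[0]) for p, y in zip(pred_idx, labels)]

    def update(self, correct):
        # State persists across batches, which is why train_forward returns metrics_list.
        self.correct += sum(correct)
        self.total += len(correct)

    def accumulate(self):
        return self.correct / self.total if self.total else 0.0


one_hot_labels = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
labels = argmax_keepdim(one_hot_labels)  # [[1], [0], [2]]
prediction = [[0.1, 0.7, 0.2], [0.3, 0.6, 0.1], [0.2, 0.2, 0.6]]

metric = TopOneAccuracy()
metric.update(metric.compute(prediction, labels))
print(metric.accumulate())  # 2 of 3 samples correct
```

Splitting `compute` from `update` lets the per-batch score be calculated once and then folded into a running total, so the same metric object can be passed through many `train_forward` calls.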

Callers (1)

main · function · 0.45

Calls (5)

create_feeds · method · 0.95
create_loss · method · 0.95
compute · method · 0.80
update · method · 0.80
forward · method · 0.45
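The five collaborators listed under Calls are invoked in a fixed order. The sketch below runs a copy of the `train_forward` body from the listing against hypothetical stubs (`StubKimModel`, `StubModel`, `StubMetric`, `StubLabels`; none of these exist in PaddleRec) and records that order. The stubs use plain lists instead of Paddle tensors, and the batch layout (inputs first, labels last) is an assumption inferred from the `*inputs, labels` unpacking.

```python
calls = []  # order in which train_forward touches its collaborators


class StubLabels:
    """Hypothetical stand-in for a one-hot label tensor of shape [batch, num_classes]."""

    def __init__(self, rows):
        self.rows = rows

    def argmax(self, axis=-1, keepdim=False):
        assert axis == -1  # only the last-axis case used by train_forward
        idx = [max(range(len(r)), key=r.__getitem__) for r in self.rows]
        return [[i] for i in idx] if keepdim else idx


class StubModel:
    def forward(self, *inputs):
        calls.append("forward")
        return [[0.9, 0.1] for _ in inputs[0]]  # fixed scores, one row per sample


class StubMetric:
    def compute(self, prediction, labels):
        calls.append("compute")
        return [int(p.index(max(p)) == y[0]) for p, y in zip(prediction, labels)]

    def update(self, correct):
        calls.append("update")
        self.correct = sum(correct)


class StubKimModel:
    def create_feeds(self, batch_data):
        calls.append("create_feeds")
        # Assumed layout: every element but the last is a model input.
        return batch_data[:-1] + [StubLabels(batch_data[-1])]

    def create_loss(self, prediction, labels):
        calls.append("create_loss")
        return 0.0  # placeholder scalar loss

    def train_forward(self, dy_model, metrics_list, batch_data, config):
        # Same body as the PaddleRec method shown above.
        *inputs, labels = self.create_feeds(batch_data)
        labels = labels.argmax(-1, keepdim=True)
        prediction = dy_model.forward(*inputs)
        loss = self.create_loss(prediction, labels)
        print_dict = {"loss": loss}
        correct = metrics_list[0].compute(prediction, labels)
        metrics_list[0].update(correct)
        return loss, metrics_list, print_dict


batch = [[[1, 2], [3, 4]], [[1, 0], [0, 1]]]  # one input (2 samples) + one-hot labels
metric = StubMetric()
loss, metrics, print_dict = StubKimModel().train_forward(
    StubModel(), [metric], batch, config={}
)
print(calls)  # ['create_feeds', 'forward', 'create_loss', 'compute', 'update']
```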

Tested by

no test coverage detected