github.com/shenweichen/DeepCTR-Torch

Method forward

deepctr_torch/models/pnn.py:78–109
(self, X)

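        self.to(device)

    def forward(self, X):
        # Embed the sparse features and collect the dense feature values.
        sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
                                                                                  self.embedding_dict)
        # Linear signal: all field embeddings concatenated, then flattened to
        # (batch_size, num_fields * embedding_size).
        linear_signal = torch.flatten(
            concat_fun(sparse_embedding_list), start_dim=1)

        if self.use_inner:
            # Pairwise inner products between field embeddings, flattened to one value per pair.
            inner_product = torch.flatten(
                self.innerproduct(sparse_embedding_list), start_dim=1)

        if self.use_outter:
            # Outer-product interactions between field pairs (OutterProductLayer).
            outer_product = self.outterproduct(sparse_embedding_list)

        # Build the product layer from whichever branches are enabled,
        # always in the order: linear, inner, outer.
        if self.use_outter and self.use_inner:
            product_layer = torch.cat(
                [linear_signal, inner_product, outer_product], dim=1)
        elif self.use_outter:
            product_layer = torch.cat([linear_signal, outer_product], dim=1)
        elif self.use_inner:
            product_layer = torch.cat([linear_signal, inner_product], dim=1)
        else:
            product_layer = linear_signal

        # Append any dense feature values and run the result through the DNN tower.
        dnn_input = combined_dnn_input([product_layer], dense_value_list)
        dnn_output = self.dnn(dnn_input)
        # Project the DNN output to the logit; nothing else is added to it.
        dnn_logit = self.dnn_linear(dnn_output)
        logit = dnn_logit

        # self.out maps the logit to the final prediction (e.g. a sigmoid for a binary task).
        y_pred = self.out(logit)

        return y_pred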
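The branch logic of forward can be traced on a single sample without PyTorch. The sketch below is a stdlib-only illustration, not DeepCTR-Torch code: `product_layer_sketch` is a hypothetical name, each field embedding is a plain list of floats, the inner-product branch is assumed to be one dot product per unordered field pair, and the outer-product branch is left as a caller-supplied function because OutterProductLayer's kernel is not shown on this page.

```python
from itertools import combinations
from typing import Callable, List, Optional

Vector = List[float]


def product_layer_sketch(field_embeddings: List[Vector],
                         use_inner: bool,
                         use_outer: bool,
                         outer_fn: Optional[Callable[[List[Vector]], Vector]] = None) -> Vector:
    """Hypothetical per-sample mirror of how forward() assembles product_layer."""
    # Linear signal: every field embedding concatenated into one flat vector.
    linear_signal = [x for emb in field_embeddings for x in emb]
    parts = [linear_signal]

    if use_inner:
        # Assumption: one dot product per unordered field pair (i < j).
        inner = [sum(a * b for a, b in zip(e_i, e_j))
                 for e_i, e_j in combinations(field_embeddings, 2)]
        parts.append(inner)

    if use_outer:
        # Hypothetical hook standing in for the outer-product layer.
        if outer_fn is None:
            raise ValueError("use_outer=True needs an outer_fn in this sketch")
        parts.append(outer_fn(field_embeddings))

    # Same ordering as the torch.cat calls: linear, then inner, then outer.
    return [x for part in parts for x in part]


fields = [[1.0, 2.0], [3.0, 4.0], [0.5, -1.0]]
print(product_layer_sketch(fields, use_inner=True, use_outer=False))
# prints [1.0, 2.0, 3.0, 4.0, 0.5, -1.0, 11.0, -1.5, -2.5]
```

With both flags off, the sketch returns only the linear signal, matching the final `else` branch in forward.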

Callers

nothing calls this directly (PyTorch reaches forward through nn.Module.__call__, e.g. model(X))
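In PyTorch, `forward` is normally invoked by calling the module object (`model(X)`), which goes through `nn.Module.__call__`; that is consistent with a static caller list coming up empty. The toy classes below are hypothetical and heavily simplified (no hooks, no autograd); they only show the dispatch pattern.

```python
from typing import Any, List


class ToyModule:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module call dispatch."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # The real nn.Module.__call__ also runs hooks; this toy only delegates to forward.
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class ToyModel(ToyModule):
    def forward(self, X: List[float]) -> List[float]:
        # Hypothetical computation standing in for a real forward pass.
        return [2 * x for x in X]


model = ToyModel()
print(model([1, 2, 3]))  # prints [2, 4, 6]; callers write model(X) rather than model.forward(X)
```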

Calls 3

concat_fun · Function · 0.85
combined_dnn_input · Function · 0.85
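Both listed callees are mostly concatenation and flattening. The per-sample sketch below assumes that `concat_fun` returns a single input unchanged and otherwise concatenates its inputs, and that `combined_dnn_input` flattens and joins the sparse-side list with the dense values. It also tallies the DNN input width, under the further assumption that the outer-product branch, like the inner one, yields one value per field pair. All function names here are hypothetical.

```python
from typing import List

Vector = List[float]


def concat_fun_sketch(inputs: List[Vector]) -> Vector:
    # Assumed behavior: a single input passes through unchanged; otherwise join in order.
    if len(inputs) == 1:
        return inputs[0]
    return [x for part in inputs for x in part]


def combined_dnn_input_sketch(sparse_parts: List[Vector], dense_parts: List[Vector]) -> Vector:
    if not sparse_parts and not dense_parts:
        # This sketch refuses empty input; the real helper's behavior is not assumed.
        raise ValueError("need at least one sparse or dense input")
    flat_sparse = [x for part in sparse_parts for x in part]
    flat_dense = [x for part in dense_parts for x in part]
    return concat_fun_sketch([flat_sparse, flat_dense])


def dnn_input_width(num_fields: int, embedding_size: int, dense_dim: int,
                    use_inner: bool, use_outer: bool) -> int:
    num_pairs = num_fields * (num_fields - 1) // 2
    width = num_fields * embedding_size  # linear signal
    if use_inner:
        width += num_pairs               # one inner product per field pair
    if use_outer:
        width += num_pairs               # assumption: one outer-product value per pair
    return width + dense_dim             # dense values appended by combined_dnn_input


print(dnn_input_width(26, 8, 13, use_inner=True, use_outer=True))  # prints 871
```

For 26 fields with 8-dimensional embeddings, the two product branches (325 pairs each) contribute more width than the linear signal (208), which is one reason the flags matter for DNN size.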

Tested by

no test coverage detected