(self, X)
| 76 | self.to(device) |
| 77 | |
def forward(self, X):
    """Run the forward pass: embeddings -> product layer -> DNN -> output.

    Embeds the inputs via ``self.input_from_feature_columns``, builds the
    product layer from a flattened linear signal plus the optional inner
    product (``self.use_inner``) and outer product (``self.use_outter``)
    terms, combines it with the dense feature values, runs the DNN tower and
    its final linear layer, and returns the result of ``self.out``.

    Args:
        X: raw model input, forwarded unchanged to
            ``self.input_from_feature_columns``. NOTE(review): presumably a
            2-D (batch, feature) tensor — confirm against the caller.

    Returns:
        The output of ``self.out`` applied to the DNN logit. Its activation
        (if any) is defined outside this method.
    """
    sparse_embeds, dense_values = self.input_from_feature_columns(
        X, self.dnn_feature_columns, self.embedding_dict)

    # The linear signal is always present. Product terms are appended in a
    # fixed order (linear, inner, outer) according to the configuration flags.
    # Dim 0 is kept by every flatten (presumably the batch dimension).
    segments = [torch.flatten(concat_fun(sparse_embeds), start_dim=1)]
    if self.use_inner:
        segments.append(
            torch.flatten(self.innerproduct(sparse_embeds), start_dim=1))
    if self.use_outter:
        segments.append(self.outterproduct(sparse_embeds))

    # With no product terms, the linear signal is used as-is rather than
    # being passed through a single-tensor concatenation.
    if len(segments) == 1:
        product_layer = segments[0]
    else:
        product_layer = torch.cat(segments, dim=1)

    dnn_input = combined_dnn_input([product_layer], dense_values)
    return self.out(self.dnn_linear(self.dnn(dnn_input)))
nothing calls this directly
no test coverage detected