(self, X)
| 79 | self.to(device) |
| 80 | |
def forward(self, X):
    """Compute predictions with an input-aware re-weighted linear + FM model.

    Two branches estimate a per-feature factor ``m_x`` from the sparse
    feature embeddings:

    * a vector-wise branch that feeds the concatenated embeddings through
      ``self.vector_wise_net``, flattens everything after the batch
      dimension and projects it with ``self.transform_matrix_P_vec``
      (giving ``m_vec``);
    * a bit-wise branch that feeds ``combined_dnn_input`` of the embeddings
      through ``self.bit_wise_net`` and projects it with
      ``self.transform_matrix_P_bit`` (giving ``m_bit``).

    Their sum ``m_x`` is passed to the linear model as
    ``sparse_feat_refine_weight``. It also rescales each embedding before
    the FM interaction term.

    Args:
        X: Model input, passed unchanged to ``input_from_feature_columns``
            and ``linear_model``. Presumably a ``(batch, n_input_columns)``
            tensor — TODO confirm against the caller.

    Returns:
        ``self.out`` applied to the sum of the linear and FM logits.

    Raises:
        ValueError: If no sparse feature embeddings are produced.
    """
    sparse_embedding_list, _ = self.input_from_feature_columns(
        X, self.dnn_feature_columns, self.embedding_dict)
    if not sparse_embedding_list:
        raise ValueError("there are no sparse features")

    # Vector-wise branch: concatenated field embeddings -> vector_wise_net,
    # then flatten all non-batch dims before the projection to m_vec.
    att_input = concat_fun(sparse_embedding_list, axis=1)
    att_out = self.vector_wise_net(att_input)
    att_out = att_out.reshape(att_out.shape[0], -1)
    m_vec = self.transform_matrix_P_vec(att_out)

    # Bit-wise branch: combined DNN input -> bit_wise_net -> projection to m_bit.
    dnn_input = combined_dnn_input(sparse_embedding_list, [])
    dnn_output = self.bit_wise_net(dnn_input)
    m_bit = self.transform_matrix_P_bit(dnn_output)

    m_x = m_vec + m_bit  # m_x is the complete input-aware factor

    logit = self.linear_model(X, sparse_feat_refine_weight=m_x)

    # NOTE(review): this concatenates the same list along the same axis as
    # att_input above; if concat_fun is a plain torch.cat wrapper, att_input
    # could be reused here to save an allocation — confirm before changing.
    fm_input = torch.cat(sparse_embedding_list, dim=1)
    # unsqueeze(-1) adds a trailing axis so each entry of m_x scales a whole
    # embedding vector; assumes m_x is (batch, n_fields) and fm_input is
    # (batch, n_fields, emb_dim) — TODO confirm.
    refined_fm_input = fm_input * m_x.unsqueeze(-1)  # \textbf{v}_{x,i}=m_{x,i} * \textbf{v}_i
    # Out-of-place add: avoids mutating the tensor returned by linear_model,
    # which could break autograd if that tensor is saved for backward.
    logit = logit + self.fm(refined_fm_input)

    y_pred = self.out(logit)

    return y_pred
nothing calls this directly
no test coverage detected