Method forward

deepctr_torch/models/difm.py, lines 81–106 (github.com/shenweichen/DeepCTR-Torch)
Signature: forward(self, X)
Computes predictions of the DIFM (Dual Input-aware Factorization Machine) model for the input batch X.

Source from the content-addressed store, hash-verified

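```python
def forward(self, X):
    # Look up one embedding per sparse feature field; each list entry has
    # shape (batch_size, 1, embedding_size). Dense values (second return) are unused.
    sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns,
                                                               self.embedding_dict)
    if not len(sparse_embedding_list) > 0:
        raise ValueError("there are no sparse features")

    # Vector-wise factor: stack field embeddings into
    # (batch_size, num_fields, embedding_size), run vector_wise_net,
    # flatten per sample, and project to one weight per field.
    att_input = concat_fun(sparse_embedding_list, axis=1)
    att_out = self.vector_wise_net(att_input)
    att_out = att_out.reshape(att_out.shape[0], -1)
    m_vec = self.transform_matrix_P_vec(att_out)

    # Bit-wise factor: flatten all embeddings into one vector per sample,
    # run bit_wise_net, and project to one weight per field.
    dnn_input = combined_dnn_input(sparse_embedding_list, [])
    dnn_output = self.bit_wise_net(dnn_input)
    m_bit = self.transform_matrix_P_bit(dnn_output)

    m_x = m_vec + m_bit  # m_x is the complete input-aware factor

    # First-order (linear) term, with sparse feature weights rescaled by m_x.
    logit = self.linear_model(X, sparse_feat_refine_weight=m_x)

    # Second-order FM term on the refined embeddings.
    fm_input = torch.cat(sparse_embedding_list, dim=1)
    refined_fm_input = fm_input * m_x.unsqueeze(-1)  # \textbf{v}_{x,i}=m_{x,i} * \textbf{v}_i
    logit += self.fm(refined_fm_input)

    # Output layer (e.g. sigmoid for binary classification).
    y_pred = self.out(logit)

    return y_pred
```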
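The refinement at the end of `forward` can be illustrated without PyTorch. The plain-Python sketch below is an approximation under stated assumptions, not the library's code: it handles a single sample, assumes sparse fields only (one scalar linear weight and one embedding per field), and uses a hypothetical `bias` argument in place of whatever bias the output layer adds. The helpers `refine_embeddings`, `fm_second_order`, `difm_logit` and `sigmoid` are hypothetical names. The sketch sums the two factors into `m_x`, rescales the linear weights and the embeddings by `m_x`, and computes the FM pairwise term with the standard square-of-sum minus sum-of-squares identity.

```python
import math
from typing import List

Vector = List[float]


def refine_embeddings(embeddings: List[Vector], m_x: Vector) -> List[Vector]:
    """Scale each field embedding v_i by its input-aware factor m_{x,i}."""
    return [[m * v for v in emb] for emb, m in zip(embeddings, m_x)]


def fm_second_order(embeddings: List[Vector]) -> float:
    """Pairwise term sum_{i<j} <v_i, v_j>, computed in O(fields * dim) as
    0.5 * sum_k [(sum_i v_ik)^2 - sum_i v_ik^2]."""
    total = 0.0
    for k in range(len(embeddings[0])):
        column = [emb[k] for emb in embeddings]
        total += sum(column) ** 2 - sum(c * c for c in column)
    return 0.5 * total


def difm_logit(linear_weights: Vector, embeddings: List[Vector],
               m_vec: Vector, m_bit: Vector, bias: float = 0.0) -> float:
    """Hypothetical single-sample DIFM-style logit (sparse fields only)."""
    m_x = [a + b for a, b in zip(m_vec, m_bit)]               # complete input-aware factor
    linear = sum(m * w for m, w in zip(m_x, linear_weights))  # reweighted first-order term
    refined = refine_embeddings(embeddings, m_x)              # v_{x,i} = m_{x,i} * v_i
    return bias + linear + fm_second_order(refined)


def sigmoid(z: float) -> float:
    """Logistic function, standing in for a binary-task output layer."""
    return 1.0 / (1.0 + math.exp(-z))


embeddings = [[1.0, 2.0], [3.0, 4.0], [0.5, -1.0]]
logit = difm_logit([0.5, -0.25, 0.125], embeddings,
                   m_vec=[0.5, 0.75, 0.25], m_bit=[1.5, 0.25, 0.75])
print(logit)           # 17.375
print(sigmoid(logit))  # a probability very close to 1
```

Here `m_x` works out to `[2.0, 1.0, 1.0]`, so only the first field's weight and embedding are doubled. The identity in `fm_second_order` avoids looping over every field pair, which matters when the number of fields is large.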

Callers

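Nothing calls this method directly. As with any PyTorch nn.Module, forward is normally invoked through the module's __call__, e.g. model(X).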

Calls 3

concat_fun · Function · 0.85
combined_dnn_input · Function · 0.85
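`concat_fun` and `combined_dnn_input` give `forward` two layouts of the same field embeddings: a field-by-embedding grid for `vector_wise_net` and one flat vector per sample for `bit_wise_net`. The plain-Python sketch below mimics those shapes with nested lists. The helper names are hypothetical, and the behavior it models (concatenation along axis 1, and, when no dense inputs are passed, concatenation on the last axis followed by flattening from dimension 1) reflects our reading of the DeepCTR-Torch helpers rather than their actual code.

```python
from typing import List

# One field's embeddings for a batch, shaped (batch_size, 1, embedding_size)
# to mirror the entries of sparse_embedding_list.
FieldBatch = List[List[List[float]]]


def concat_on_field_axis(fields: List[FieldBatch]) -> List[List[List[float]]]:
    """Join fields along axis 1: F tensors of (batch, 1, E) -> (batch, F, E)."""
    batch_size = len(fields[0])
    return [[field[b][0] for field in fields] for b in range(batch_size)]


def flatten_for_dnn(fields: List[FieldBatch]) -> List[List[float]]:
    """Join fields on the last axis, then flatten: -> (batch, F * E), field by field."""
    batch_size = len(fields[0])
    return [[x for field in fields for x in field[b][0]] for b in range(batch_size)]


field_a = [[[1.0, 2.0]], [[5.0, 6.0]]]  # batch of 2, embedding_size 2
field_b = [[[3.0, 4.0]], [[7.0, 8.0]]]
print(concat_on_field_axis([field_a, field_b]))
# [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
print(flatten_for_dnn([field_a, field_b]))
# [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
```

Both layouts hold the same numbers in the same order; only the grouping differs, which is why the vector-wise branch can treat each field embedding as a unit while the bit-wise branch sees individual elements.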

Tested by

no test coverage detected