MCPcopy
hub / github.com/shenweichen/DeepCTR-Torch / __init__

Method __init__

deepctr_torch/models/difm.py:39–79  ·  view source on GitHub ↗
(self,
                 linear_feature_columns, dnn_feature_columns, att_head_num=4,
                 att_res=True, dnn_hidden_units=(256, 128),
                 l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
                 dnn_dropout=0,
                 dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None)

Source from the content-addressed store, hash-verified

37 """
38
def __init__(self,
             linear_feature_columns, dnn_feature_columns, att_head_num=4,
             att_res=True, dnn_hidden_units=(256, 128),
             l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
             dnn_dropout=0,
             dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):
    """Assemble the DIFM sub-networks and register their L2 penalties.

    Builds, in this order: an ``FM`` layer, a vector-wise ``InteractingLayer``,
    a bit-wise ``DNN`` whose input width excludes dense features, and two
    bias-free ``nn.Linear`` maps with ``sparse_feat_num`` outputs each (one
    taking ``sparse_feat_num * embedding_size`` inputs, one taking
    ``dnn_hidden_units[-1]`` inputs). The construction order is kept as-is
    because building submodules may consume the RNG and it fixes the order in
    which parameters are registered.

    :param linear_feature_columns: forwarded to the base model.
    :param dnn_feature_columns: forwarded to the base model; also used to size
        the DNN input and counted (SparseFeat / VarLenSparseFeat entries) to
        obtain ``sparse_feat_num``.
    :param att_head_num: head count passed to ``InteractingLayer``.
    :param att_res: residual flag passed to ``InteractingLayer``.
    :param dnn_hidden_units: hidden sizes of the bit-wise ``DNN``; must be
        non-empty, and its last entry is the input width of
        ``transform_matrix_P_bit``.
    :param l2_reg_linear: forwarded to the base model.
    :param l2_reg_embedding: forwarded to the base model.
    :param l2_reg_dnn: passed to the ``DNN`` as ``l2_reg`` and registered as the
        L2 strength for both branches and both projection matrices.
    :param init_std: forwarded to the base model and the ``DNN``.
    :param seed: forwarded to the base model.
    :param dnn_dropout: passed to the ``DNN`` as ``dropout_rate``.
    :param dnn_activation: passed to the ``DNN`` as ``activation``.
    :param dnn_use_bn: passed to the ``DNN`` as ``use_bn``.
    :param task: forwarded to the base model.
    :param device: device for the submodules and for the final ``self.to``.
    :param gpus: forwarded to the base model.
    :raises ValueError: if ``dnn_hidden_units`` is empty.
    """
    super().__init__(linear_feature_columns, dnn_feature_columns,
                     l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding,
                     init_std=init_std, seed=seed, task=task, device=device, gpus=gpus)

    # The bit-wise projection below reads dnn_hidden_units[-1], so an empty
    # tuple/list must be rejected up front.
    if len(dnn_hidden_units) == 0:
        raise ValueError("dnn_hidden_units is null!")

    self.fm = FM()

    # Vector-wise branch: the InteractingLayer used in AutoInt (multi-head
    # self-attention); att_res is passed as its residual flag.
    self.vector_wise_net = InteractingLayer(self.embedding_size, att_head_num, att_res,
                                            scaling=True, device=device)

    # Bit-wise branch: a DNN whose input width is computed without dense features.
    dnn_input_dim = self.compute_input_dim(dnn_feature_columns, include_dense=False)
    self.bit_wise_net = DNN(dnn_input_dim, dnn_hidden_units,
                            activation=dnn_activation, l2_reg=l2_reg_dnn,
                            dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
                            init_std=init_std, device=device)

    sparse_types = (SparseFeat, VarLenSparseFeat)
    self.sparse_feat_num = sum(1 for column in dnn_feature_columns
                               if isinstance(column, sparse_types))

    # Bias-free projections, each producing sparse_feat_num outputs.
    self.transform_matrix_P_vec = nn.Linear(self.sparse_feat_num * self.embedding_size,
                                            self.sparse_feat_num, bias=False).to(device)
    self.transform_matrix_P_bit = nn.Linear(dnn_hidden_units[-1],
                                            self.sparse_feat_num, bias=False).to(device)

    def _decayable(module):
        # Lazily yield the (name, parameter) pairs from named_parameters() whose
        # name contains 'weight' and does not contain 'bn'.
        # NOTE(review): if InteractingLayer names its projection parameters
        # without 'weight' (e.g. W_Query), this selects nothing for the
        # vector-wise branch and it gets no L2 penalty here — confirm against
        # deepctr_torch.layers.
        return (entry for entry in module.named_parameters()
                if 'weight' in entry[0] and 'bn' not in entry[0])

    for branch in (self.vector_wise_net, self.bit_wise_net):
        self.add_regularization_weight(_decayable(branch), l2=l2_reg_dnn)
    for projection in (self.transform_matrix_P_vec, self.transform_matrix_P_bit):
        self.add_regularization_weight(projection.weight, l2=l2_reg_dnn)

    self.to(device)
80
81 def forward(self, X):
82 sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns,

Callers

nothing calls this directly

Calls 5

FM (class) · 0.85
InteractingLayer (class) · 0.85
DNN (class) · 0.85
compute_input_dim (method) · 0.45

Tested by

no test coverage detected