MCPcopy
hub / github.com/shenweichen/DeepCTR-Torch / __init__

Method __init__

deepctr_torch/models/autoint.py:39–75  ·  view source on GitHub ↗
(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
                 att_head_num=2, att_res=True, dnn_hidden_units=(256, 128), dnn_activation='relu',
                 l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024,
                 task='binary', device='cpu', gpus=None)

Source from the content-addressed store, hash-verified

37 """
38
def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
             att_head_num=2, att_res=True, dnn_hidden_units=(256, 128), dnn_activation='relu',
             l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024,
             task='binary', device='cpu', gpus=None):
    """Build the AutoInt model: optional DNN tower, stacked interacting layers, and a 1-unit linear head.

    Args:
        linear_feature_columns: forwarded unchanged to the base-class constructor.
        dnn_feature_columns: forwarded to the base-class constructor and used to size the
            DNN input via ``self.compute_input_dim``. When empty, no DNN tower is built.
        att_layer_num: number of ``InteractingLayer`` modules to stack; a value > 0 also
            adds the attention output to the final linear layer's input width.
        att_head_num: passed to every ``InteractingLayer`` (presumably its head count).
        att_res: passed to every ``InteractingLayer`` (presumably a residual-connection flag).
        dnn_hidden_units: layer widths of the DNN tower; its last entry is the DNN's share
            of the final linear layer's input width.
        dnn_activation, dnn_dropout, dnn_use_bn: forwarded to ``DNN`` as ``activation``,
            ``dropout_rate`` and ``use_bn``.
        l2_reg_dnn: passed to ``DNN`` as ``l2_reg`` and also registered as the L2 coefficient
            for DNN parameters whose name contains 'weight' but not 'bn'.
        l2_reg_embedding, seed, task, gpus: forwarded to the base-class constructor
            (``l2_reg_linear`` is always passed as 0).
        init_std: forwarded to both the base-class constructor and ``DNN``.
        device: forwarded to the base class and to each submodule; the whole model is moved
            there at the end.

    Raises:
        ValueError: if ``dnn_hidden_units`` is empty and ``att_layer_num <= 0``.
    """
    super(AutoInt, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=0,
                                  l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
                                  device=device, gpus=gpus)

    has_hidden_units = len(dnn_hidden_units) > 0
    if not has_hidden_units and att_layer_num <= 0:
        raise ValueError("Either hidden_layer or att_layer_num must > 0")
    self.use_dnn = len(dnn_feature_columns) > 0 and has_hidden_units

    field_num = len(self.embedding_dict)
    embedding_size = self.embedding_size

    # Widths of the tower outputs feeding the final linear layer, in DNN-then-attention order.
    head_input_widths = []
    if has_hidden_units:
        head_input_widths.append(dnn_hidden_units[-1])
    if att_layer_num > 0:
        # NOTE(review): assumes each InteractingLayer keeps the per-field width at embedding_size,
        # so the flattened attention output is field_num * embedding_size — confirm.
        head_input_widths.append(field_num * embedding_size)
    if not head_input_widths:
        # Not reachable for ordinary numeric att_layer_num because of the ValueError guard above.
        raise NotImplementedError
    dnn_linear_in_feature = sum(head_input_widths)

    # Submodules are built in a fixed order (dnn_linear, dnn, int_layers). Keep it: weight init
    # may depend on RNG state seeded by the base class from `seed` (not visible here) — confirm.
    # NOTE(review): the width above includes dnn_hidden_units[-1] even when dnn_feature_columns is
    # empty and self.dnn is therefore not built — confirm forward() never uses that path.
    self.dnn_linear = nn.Linear(dnn_linear_in_feature, 1, bias=False).to(device)
    self.dnn_hidden_units = dnn_hidden_units
    self.att_layer_num = att_layer_num

    if self.use_dnn:
        dnn_input_dim = self.compute_input_dim(dnn_feature_columns)
        self.dnn = DNN(dnn_input_dim, dnn_hidden_units, activation=dnn_activation, l2_reg=l2_reg_dnn,
                       dropout_rate=dnn_dropout, use_bn=dnn_use_bn, init_std=init_std, device=device)

        def _is_l2_target(named_param):
            # Regularize weight tensors only; skip any parameter whose name marks it as batch-norm.
            param_name = named_param[0]
            return 'weight' in param_name and 'bn' not in param_name

        # A lazy filter object is handed over, as before; whether the base class materializes it
        # is not visible here.
        self.add_regularization_weight(filter(_is_l2_target, self.dnn.named_parameters()), l2=l2_reg_dnn)

    interacting_layers = [InteractingLayer(embedding_size, att_head_num, att_res, device=device)
                          for _ in range(att_layer_num)]
    self.int_layers = nn.ModuleList(interacting_layers)

    self.to(device)
76
77 def forward(self, X):
78

Callers

nothing calls this directly

Calls 4

DNN · Class · 0.85
InteractingLayer · Class · 0.85
compute_input_dim · Method · 0.45

Tested by

no test coverage detected