(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
att_head_num=2, att_res=True, dnn_hidden_units=(256, 128), dnn_activation='relu',
l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024,
task='binary', device='cpu', gpus=None)
| 37 | """ |
| 38 | |
def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
             att_head_num=2, att_res=True, dnn_hidden_units=(256, 128), dnn_activation='relu',
             l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024,
             task='binary', device='cpu', gpus=None):
    """Build the AutoInt sub-modules on top of the shared base model.

    The final ``dnn_linear`` projection (bias-free, one output) takes the
    concatenation of whichever branches are enabled:

    * ``dnn_hidden_units[-1]`` features when ``dnn_hidden_units`` is non-empty;
    * ``field_num * embedding_size`` features when ``att_layer_num > 0``
      (presumably the flattened interacting-layer output; confirm against
      ``forward``).

    Args:
        linear_feature_columns: Forwarded to the base model.
        dnn_feature_columns: Forwarded to the base model. Also used to size the
            DNN input via ``self.compute_input_dim``.
        att_layer_num: Number of ``InteractingLayer`` modules stacked in
            ``self.int_layers``.
        att_head_num: Forwarded to every ``InteractingLayer``.
        att_res: Forwarded to every ``InteractingLayer``.
        dnn_hidden_units: Forwarded to ``DNN``. Its last entry contributes to
            the input width of ``dnn_linear``.
        dnn_activation: Forwarded to ``DNN`` as ``activation``.
        l2_reg_dnn: Forwarded to ``DNN``. Also registered as the L2
            coefficient for the DNN parameters whose names contain
            ``'weight'`` but not ``'bn'``.
        l2_reg_embedding: Forwarded to the base model.
        dnn_use_bn: Forwarded to ``DNN`` as ``use_bn``.
        dnn_dropout: Forwarded to ``DNN`` as ``dropout_rate``.
        init_std: Forwarded to the base model and to ``DNN``.
        seed: Forwarded to the base model.
        task: Forwarded to the base model.
        device: Forwarded to the base model and submodules. The whole module
            is moved there at the end.
        gpus: Forwarded to the base model.

    Raises:
        ValueError: If ``dnn_hidden_units`` is empty and ``att_layer_num <= 0``,
            i.e. neither branch would feed ``dnn_linear``.
    """
    # The linear part is built with l2_reg_linear=0 (hard-coded here).
    super(AutoInt, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=0,
                                  l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
                                  device=device, gpus=gpus)

    has_dnn_layers = len(dnn_hidden_units) > 0
    has_att_layers = att_layer_num > 0
    if not (has_dnn_layers or has_att_layers):
        raise ValueError("Either dnn_hidden_units must be non-empty or att_layer_num must be > 0")

    # NOTE(review): the DNN is only built when there are DNN feature columns,
    # but dnn_linear below is sized from dnn_hidden_units alone. With empty
    # dnn_feature_columns and non-empty dnn_hidden_units the two disagree.
    self.use_dnn = len(dnn_feature_columns) > 0 and has_dnn_layers
    # NOTE(review): this counts embedding tables, which presumably equals the
    # number of sparse fields. Fields that share an embedding table would
    # undercount. TODO confirm.
    field_num = len(self.embedding_dict)
    embedding_size = self.embedding_size

    # Sum the widths of the enabled branches. The guard above ensures at
    # least one term is added.
    dnn_linear_in_feature = 0
    if has_dnn_layers:
        dnn_linear_in_feature += dnn_hidden_units[-1]
    if has_att_layers:
        dnn_linear_in_feature += field_num * embedding_size

    # Construction order (dnn_linear, dnn, int_layers) is kept as-is. It
    # determines RNG consumption during weight init and parameter order.
    self.dnn_linear = nn.Linear(dnn_linear_in_feature, 1, bias=False).to(device)
    self.dnn_hidden_units = dnn_hidden_units
    self.att_layer_num = att_layer_num
    if self.use_dnn:
        self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units,
                       activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
                       init_std=init_std, device=device)
        # Regularize only weight tensors, and skip batch-norm parameters.
        self.add_regularization_weight(
            (named for named in self.dnn.named_parameters()
             if 'weight' in named[0] and 'bn' not in named[0]),
            l2=l2_reg_dnn)
    self.int_layers = nn.ModuleList(
        [InteractingLayer(embedding_size, att_head_num, att_res, device=device) for _ in range(att_layer_num)])

    self.to(device)
| 76 | |
| 77 | def forward(self, X): |
| 78 |
nothing calls this directly
no test coverage detected