Args: X: input Tensor [batch_size x hidden_dim] sparse_embedding_dict: nn.ModuleDict, {embedding_name: nn.Embedding} sparse_input_dict: OrderedDict, {feature_name:(start, start+dimension)} sparse_feature_columns: list, sparse features
(X, sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(),
mask_feat_list=(), to_list=False)
| 181 | |
| 182 | |
def embedding_lookup(X, sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(),
                     mask_feat_list=(), to_list=False):
    """Look up embeddings for sparse features and group them by ``group_name``.

    Args:
        X: input Tensor [batch_size x hidden_dim]. Each feature's column slice
            is cast to ``long`` before being passed to its embedding.
        sparse_embedding_dict: nn.ModuleDict, {embedding_name: nn.Embedding}.
        sparse_input_dict: OrderedDict, {feature_name: (start, start + dimension)},
            the column range of each feature inside ``X``.
        sparse_feature_columns: list of sparse feature columns. Each must expose
            ``name``, ``embedding_name`` and ``group_name``.
        return_feat_list: iterable of feature names to return. The default
            ``()`` returns all features.
        mask_feat_list: iterable of feature names to be masked in the hash
            transform. Currently unused because hashing is not implemented.
        to_list: bool. If True, return a flat list of embeddings instead of
            the grouped dict.

    Returns:
        A ``defaultdict(list)`` mapping each ``group_name`` to the embedding
        tensors of its features, in ``sparse_feature_columns`` order. When
        ``to_list`` is True, returns instead a flat list of those tensors,
        concatenated group by group (groups in first-seen order).
    """
    # Build the selection once; an empty selection means "return every feature".
    wanted = set(return_feat_list)
    group_embedding_dict = defaultdict(list)
    for fc in sparse_feature_columns:
        if wanted and fc.name not in wanted:
            continue
        # TODO: add hash function (would apply mask_feat_list); fc.use_hash is
        # not supported in this version.
        start, end = sparse_input_dict[fc.name]
        input_tensor = X[:, start:end].long()
        emb = sparse_embedding_dict[fc.embedding_name](input_tensor)
        group_embedding_dict[fc.group_name].append(emb)
    if to_list:
        return list(chain.from_iterable(group_embedding_dict.values()))
    return group_embedding_dict
| 211 | |
| 212 | |
| 213 | def varlen_embedding_lookup(X, embedding_dict, sequence_input_dict, varlen_sparse_feature_columns): |
no test coverage detected