Repository: github.com/shenweichen/DeepCTR-Torch

Function embedding_lookup

deepctr_torch/inputs.py:183–210

Looks up the embedding of each selected sparse feature column in the input tensor X and groups the results by the column's group_name. sparse_input_dict gives each feature's column range in X, and sparse_embedding_dict gives the nn.Embedding to use, keyed by the column's embedding_name.

Signature: embedding_lookup(X, sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(), mask_feat_list=(), to_list=False)
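Each lookup is a column slice followed by a table lookup. The following is a minimal plain-Python sketch of that single step, assuming a hypothetical layout where user_id and item_id each occupy one column of X. slice_ids and embed are illustrative helpers standing in for X[:, start:end].long() and nn.Embedding; they are not DeepCTR-Torch APIs.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def slice_ids(X: Matrix, start: int, end: int) -> List[List[int]]:
    """Hypothetical plain-Python analogue of X[:, start:end].long()."""
    return [[int(v) for v in row[start:end]] for row in X]


def embed(table: Matrix, idx: List[List[int]]) -> List[List[List[float]]]:
    """Hypothetical plain-Python analogue of nn.Embedding(idx): each id becomes
    its table row, so the result has shape [batch_size, end - start, embedding_dim]."""
    return [[table[i] for i in row] for row in idx]


# Hypothetical layout: each sparse feature occupies one column of X.
sparse_input_dict: Dict[str, Tuple[int, int]] = {"user_id": (0, 1), "item_id": (1, 2)}

# Two samples; ids are stored as floats, which is why the source casts with .long().
X: Matrix = [[3.0, 0.0],
             [1.0, 2.0]]

# Hypothetical "item_id" table: vocabulary size 3, embedding_dim 2.
item_table: Matrix = [[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]

start, end = sparse_input_dict["item_id"]
item_idx = slice_ids(X, start, end)      # [[0], [2]]
item_emb = embed(item_table, item_idx)   # [[[0.0, 1.0]], [[20.0, 21.0]]]
```

Because the slice keeps its column dimension, each embedding has shape [batch_size, end - start, embedding_dim], which is [2, 1, 2] here.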


```python
from collections import defaultdict
from itertools import chain

import numpy as np


def embedding_lookup(X, sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(),
                     mask_feat_list=(), to_list=False):
    """
    Args:
        X: input Tensor [batch_size x hidden_dim]
        sparse_embedding_dict: nn.ModuleDict, {embedding_name: nn.Embedding}
        sparse_input_dict: OrderedDict, {feature_name:(start, start+dimension)}
        sparse_feature_columns: list, sparse features
        return_feat_list: list, names of features to be returned, default () -> return all features
        mask_feat_list: list, names of features to be masked in hash transform
        to_list: bool, if True, return a flat list of embeddings instead of the dict
    Return:
        group_embedding_dict: defaultdict(list), {group_name: [embedding, ...]}
            (a flat list of embeddings when to_list=True)
    """
    group_embedding_dict = defaultdict(list)
    for fc in sparse_feature_columns:
        feature_name = fc.name
        embedding_name = fc.embedding_name
        # An empty return_feat_list selects every feature column.
        if (len(return_feat_list) == 0 or feature_name in return_feat_list):
            # TODO: add hash function (mask_feat_list is unused until then)
            # if fc.use_hash:
            #     raise NotImplementedError("hash function is not implemented in this version!")
            lookup_idx = np.array(sparse_input_dict[feature_name])
            # Slice this feature's columns out of X and cast them to integer ids.
            input_tensor = X[:, lookup_idx[0]:lookup_idx[1]].long()
            # Columns that share an embedding_name share the same nn.Embedding.
            emb = sparse_embedding_dict[embedding_name](input_tensor)
            group_embedding_dict[fc.group_name].append(emb)
    if to_list:
        # Flatten the per-group lists, group by group, into one list.
        return list(chain.from_iterable(group_embedding_dict.values()))
    return group_embedding_dict
```
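Grouping and flattening are easiest to see with several columns. The sketch below replays the function's control flow on nested Python lists instead of tensors. FeatCol, make_table_fn, lookup_sketch and all feature, table and group names are hypothetical stand-ins for the real feature columns and nn.Embedding modules.

```python
from collections import defaultdict
from itertools import chain
from typing import NamedTuple


class FeatCol(NamedTuple):
    """Hypothetical stand-in for a sparse feature column: only the three
    attributes that embedding_lookup reads."""
    name: str
    embedding_name: str
    group_name: str


def make_table_fn(table):
    """Hypothetical stand-in for nn.Embedding: maps each id to its table row."""
    return lambda idx: [[table[i] for i in row] for row in idx]


def lookup_sketch(X, embedding_fns, input_dict, columns, return_feat_list=(), to_list=False):
    """Mirrors embedding_lookup's control flow on nested Python lists."""
    groups = defaultdict(list)
    for fc in columns:
        # An empty return_feat_list means "look up every column".
        if len(return_feat_list) == 0 or fc.name in return_feat_list:
            start, end = input_dict[fc.name]
            idx = [[int(v) for v in row[start:end]] for row in X]  # like X[:, start:end].long()
            groups[fc.group_name].append(embedding_fns[fc.embedding_name](idx))
    if to_list:
        # Flatten group by group, in the order each group was first seen.
        return list(chain.from_iterable(groups.values()))
    return groups


embedding_fns = {
    "user_emb": make_table_fn([[0.0], [1.0], [2.0], [3.0]]),
    "item_emb": make_table_fn([[10.0], [11.0], [12.0]]),
}
columns = [
    FeatCol("user_id", "user_emb", "user_side"),
    FeatCol("item_id", "item_emb", "item_side"),
    FeatCol("clicked_item", "item_emb", "user_side"),  # shares the item table
]
input_dict = {"user_id": (0, 1), "item_id": (1, 2), "clicked_item": (2, 3)}
X = [[3.0, 0.0, 2.0],
     [1.0, 2.0, 1.0]]

groups = lookup_sketch(X, embedding_fns, input_dict, columns)
flat = lookup_sketch(X, embedding_fns, input_dict, columns, to_list=True)
only_item = lookup_sketch(X, embedding_fns, input_dict, columns, return_feat_list=("item_id",))
```

Here list(groups) is ["user_side", "item_side"], and flat is ordered user_id, clicked_item, item_id. With to_list=True the lists are concatenated group by group in the order the groups first appear, so the flat order need not match the order of sparse_feature_columns. only_item contains just the item_side group.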

Callers (3)

_get_emb (method) · 0.85
_get_deep_input_emb (method) · 0.85
forward (method) · 0.85

Calls (1)

append (method) · 0.80

Tested by

No test coverage detected.