| 18 | |
| 19 | |
class SparseFeat(namedtuple('SparseFeat',
                            ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'dtype', 'embedding_name',
                             'group_name'])):
    """Immutable description of one categorical (sparse) input feature.

    Fields:
        name: feature name.
        vocabulary_size: number of distinct values the feature can take.
        embedding_dim: embedding width (default 4), or the string ``"auto"``
            to derive it as ``6 * int(vocabulary_size ** 0.25)``.
        use_hash: stored as given; on-the-fly feature hashing is not
            implemented in this (torch) version, so ``True`` only emits a
            warning.
        dtype: dtype label, default ``"int32"``.
        embedding_name: defaults to ``name`` when omitted (presumably so that
            several features can share one embedding table -- confirm with
            callers).
        group_name: defaults to ``DEFAULT_GROUP_NAME``.

    Instances hash by ``name`` only. Equality is ordinary tuple equality, so
    equal instances always produce equal hashes.
    """
    __slots__ = ()

    def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="int32", embedding_name=None,
                group_name=DEFAULT_GROUP_NAME):
        if embedding_name is None:
            embedding_name = name
        if embedding_dim == "auto":
            embedding_dim = 6 * int(pow(vocabulary_size, 0.25))
        if use_hash:
            # Use the warnings machinery instead of print(): the message goes to
            # stderr, can be filtered or silenced by callers, and stacklevel=2
            # attributes it to the line that constructed the feature.
            import warnings
            warnings.warn(
                "Notice! Feature Hashing on the fly currently is not supported in torch version,you can use tensorflow version!",
                stacklevel=2)
        return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype,
                                              embedding_name, group_name)

    def __hash__(self):
        # Hash by name only. This stays consistent with the inherited tuple
        # __eq__, because equal tuples necessarily have equal names.
        return hash(self.name)
| 39 | |
| 40 | |
| 41 | class VarLenSparseFeat(namedtuple('VarLenSparseFeat', |
no outgoing calls