MCPcopy
hub / github.com/Tencent/NeuralNLP-NeuralClassifier / Embedding

Class Embedding

model/embedding.py:62–127  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

60
61
62class Embedding(torch.nn.Module):
def __init__(self, dict_map, embedding_dim, name, config, padding_idx=None,
             pretrained_embedding_file=None, mode=EmbeddingProcessType.FLAT,
             dropout=0, init_type=InitType.XAVIER_UNIFORM, low=0, high=1,
             mean=0, std=1, activation_type=ActivationType.NONE,
             fan_mode=FAN_MODE.FAN_IN, negative_slope=0,
             model_mode=ModeType.TRAIN):
    """Build the embedding layer and fill its weight matrix.

    The weights are first produced by ``init_tensor``; rows may then be
    overwritten from a pretrained file, the padding row is zeroed, and the
    result is copied into the layer's ``weight``.

    Args:
        dict_map: Vocabulary mapping. ``len(dict_map)`` is the number of
            embedding rows; it is also handed to
            ``load_pretrained_embedding`` to map tokens to row indices.
        embedding_dim: Width of each embedding row.
        name: Label forwarded to ``load_pretrained_embedding``.
        config: Passed to ``Logger``.
        padding_idx: Optional row index. Given to ``torch.nn.Embedding``
            in FLAT mode, and that row of the weight table is set to 0.0
            whenever it is not ``None``.
        pretrained_embedding_file: Path of a pretrained embedding file.
            Only read when ``model_mode`` is ``ModeType.TRAIN`` and the
            value is neither ``None`` nor ``""``.
        mode: ``EmbeddingProcessType.FLAT`` selects ``torch.nn.Embedding``;
            any other value selects ``torch.nn.EmbeddingBag`` and is passed
            on as that layer's ``mode`` argument.
        dropout: Probability ``p`` for ``torch.nn.Dropout``.
        init_type, low, high, mean, std, activation_type, fan_mode,
        negative_slope: Forwarded unchanged to ``init_tensor``.
        model_mode: Run mode; pretrained vectors are loaded only in
            ``ModeType.TRAIN``.
    """
    super(Embedding, self).__init__()
    self.logger = Logger(config)
    self.dropout = torch.nn.Dropout(p=dropout)
    self.mode = mode

    vocab_size = len(dict_map)
    if self.mode == EmbeddingProcessType.FLAT:
        layer = torch.nn.Embedding(
            vocab_size, embedding_dim, padding_idx=padding_idx)
    else:
        layer = torch.nn.EmbeddingBag(
            vocab_size, embedding_dim, mode=mode)
    self.embedding = layer

    # Start from a freshly initialised table with the layer's weight shape.
    weight_table = init_tensor(
        tensor=torch.empty(vocab_size, embedding_dim),
        init_type=init_type, low=low, high=high, mean=mean, std=std,
        activation_type=activation_type, fan_mode=fan_mode,
        negative_slope=negative_slope)

    # Pretrained vectors are only consulted while training and only when a
    # non-empty file path was supplied.
    if model_mode == ModeType.TRAIN:
        if pretrained_embedding_file is not None and \
                pretrained_embedding_file != "":
            self.load_pretrained_embedding(
                weight_table, dict_map, embedding_dim, name,
                pretrained_embedding_file)

    # Zero the padding row after any pretrained load so it stays zero.
    if padding_idx is not None:
        weight_table[padding_idx] = 0.0

    self.embedding.weight.data.copy_(weight_table)
93
def forward(self, vocab_ids, offset=None):
    """Look up embeddings for ``vocab_ids`` and apply dropout.

    Args:
        vocab_ids: Ids handed to the embedding layer.
        offset: Second positional argument for the layer when the mode is
            not ``EmbeddingProcessType.FLAT`` (the layer built for that
            case is ``torch.nn.EmbeddingBag``). Ignored in FLAT mode.

    Returns:
        The layer output after ``self.dropout`` has been applied.
    """
    if self.mode == EmbeddingProcessType.FLAT:
        return self.dropout(self.embedding(vocab_ids))
    return self.dropout(self.embedding(vocab_ids, offset))
100
101 def load_pretrained_embedding(
102 self, embedding_lookup_table, dict_map, embedding_dim, name,
103 pretrained_embedding_file):
104 self.logger.warn(
105 "Load %s embedding from %s" % (name, pretrained_embedding_file))
106 with open(pretrained_embedding_file) as fin:
107 num_pretrained = 0
108 for line in fin:
109 data = line.strip().split(' ')
110 # Check embedding info
111 if len(data) == 2:
112 assert int(data[1]) == embedding_dim, \
113 "Pretrained embedding dim not matching: %s, %d" % (
114 data[1], embedding_dim)
115 continue
116 if data[0] not in dict_map:
117 continue
118 embedding = torch.FloatTensor([float(i) for i in data[1:]])
119 embedding_lookup_table[dict_map[data[0]]] = embedding

Callers 3

__init__Method · 0.90
__init__Method · 0.90
__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected