MCPcopy Index your code
hub / github.com/Tencent/NeuralNLP-NeuralClassifier / RegionEmbeddingLayer

Class RegionEmbeddingLayer

model/embedding.py:141–211  ·  view source on GitHub ↗

Reference: A New Method of Region Embedding for Text Classification

Source retrieved from the content-addressed store (hash-verified).

139
140
141class RegionEmbeddingLayer(torch.nn.Module):
142 """
143 Reference: A New Method of Region Embedding for Text Classification
144 """
145
def __init__(self, dict_map, embedding_dim, region_size, name, config,
             padding=None, pretrained_embedding_file=None, dropout=0,
             init_type=InitType.XAVIER_UNIFORM, low=0, high=1, mean=0,
             std=1, fan_mode=FAN_MODE.FAN_IN, model_mode=ModeType.TRAIN,
             region_embedding_type=RegionEmbeddingType.WC):
    """Build the word and context embedding tables used for region embedding.

    Args:
        dict_map: Vocabulary mapping, forwarded unchanged to both
            ``Embedding`` tables.
        embedding_dim: Size of each per-token embedding vector.
        region_size: Number of tokens in a region. Must be odd so that
            each region is centred on one token with ``radius`` tokens on
            either side.
        name: Suffix appended to "RegionWord" / "RegionContext" to name the
            two embedding tables.
        config: Configuration object forwarded to ``Embedding``.
        padding: Forwarded as ``padding_idx`` to both tables.
        pretrained_embedding_file: Forwarded to the word table only.
        dropout, init_type, low, high, mean, std, fan_mode: Forwarded to
            both ``Embedding`` tables.
        model_mode: Forwarded to the word table only.
        region_embedding_type: Selects the WC or CW branch in ``forward``.

    Raises:
        ValueError: If ``region_size`` is not odd.
    """
    super(RegionEmbeddingLayer, self).__init__()
    # Validate explicitly rather than with ``assert``, because ``python -O``
    # strips asserts. ``forward`` relies on region_size == 2 * radius + 1.
    if region_size % 2 != 1:
        raise ValueError(
            "region_size must be odd, got {}".format(region_size))
    self.region_embedding_type = region_embedding_type
    self.region_size = region_size
    # Number of tokens on each side of the centre token.
    self.radius = region_size // 2
    self.embedding_dim = embedding_dim
    # Word table: one embedding_dim vector per token.
    self.embedding = Embedding(
        dict_map, embedding_dim, "RegionWord" + name, config=config,
        padding_idx=padding,
        pretrained_embedding_file=pretrained_embedding_file,
        dropout=dropout, init_type=init_type, low=low, high=high, mean=mean,
        std=std, fan_mode=fan_mode, model_mode=model_mode)
    # Context table: one (embedding_dim * region_size) vector per token.
    # ``forward`` views it as region_size separate embedding_dim vectors.
    # NOTE(review): model_mode is not forwarded here, unlike the word table.
    # Presumably harmless because no pretrained file is passed; confirm.
    self.context_embedding = Embedding(
        dict_map, embedding_dim * region_size, "RegionContext" + name,
        config=config, padding_idx=padding, dropout=dropout,
        init_type=init_type, low=low, high=high, mean=mean, std=std,
        fan_mode=fan_mode)
168
169 def forward(self, vocab_ids):
170 seq_length = vocab_ids.size(1)
171 actual_length = vocab_ids.size(1) - self.radius * 2
172 trim_vocab_id = vocab_ids[:, self.radius:seq_length - self.radius]
173 slice_vocabs = \
174 [vocab_ids[:, i:i + self.region_size] for i in
175 range(actual_length)]
176 slice_vocabs = torch.cat(slice_vocabs, 1)
177 slice_vocabs = \
178 slice_vocabs.view(-1, actual_length, self.region_size)
179
180 if self.region_embedding_type == RegionEmbeddingType.WC:
181 vocab_embedding = self.embedding(slice_vocabs)
182 context_embedding = self.context_embedding(trim_vocab_id)
183 context_embedding = context_embedding.view(
184 -1, actual_length, self.region_size, self.embedding_dim)
185 region_embedding = vocab_embedding * context_embedding
186 region_embedding, _ = region_embedding.max(2)
187 elif self.region_embedding_type == RegionEmbeddingType.CW:
188 vocab_embedding = self.embedding(trim_vocab_id).unsqueeze(2)
189 context_embedding = self.context_embedding(slice_vocabs)
190 size = context_embedding.size()
191 context_embedding = context_embedding.view(
192 size[0], size[1], size[2], self.region_size, self.embedding_dim)
193 mask = torch.ones(
194 [self.region_size, self.region_size, self.embedding_dim])
195
196 for i in range(self.region_size):
197 mask[i][self.region_size - i - 1] = 0.
198 neg_mask = mask * -65500.0

Callers 1

__init__ (method) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected