Reference: A New Method of Region Embedding for Text Classification
| 139 | |
| 140 | |
| 141 | class RegionEmbeddingLayer(torch.nn.Module): |
| 142 | """ |
| 143 | Reference: A New Method of Region Embedding for Text Classification |
| 144 | """ |
| 145 | |
def __init__(self, dict_map, embedding_dim, region_size, name, config,
             padding=None, pretrained_embedding_file=None, dropout=0,
             init_type=InitType.XAVIER_UNIFORM, low=0, high=1, mean=0,
             std=1, fan_mode=FAN_MODE.FAN_IN, model_mode=ModeType.TRAIN,
             region_embedding_type=RegionEmbeddingType.WC):
    """Create the word and context embedding tables of a region layer.

    Two ``Embedding`` modules are built from the same ``dict_map``:

    * ``self.embedding`` -- named ``"RegionWord" + name``, with
      ``embedding_dim`` values per entry.
    * ``self.context_embedding`` -- named ``"RegionContext" + name``, with
      ``embedding_dim * region_size`` values per entry, i.e. one
      ``embedding_dim`` slice per position of the region.

    Args:
        dict_map: Vocabulary mapping, forwarded to both ``Embedding``s.
        embedding_dim: Size of a single embedding vector.
        region_size: Width of a region. Must be a positive odd number so
            the region has a centre with ``region_size // 2`` positions
            on either side (stored as ``self.radius``).
        name: Suffix appended to the two embedding names.
        config: Project configuration object, forwarded to ``Embedding``.
        padding: Forwarded to ``Embedding`` as ``padding_idx``.
        pretrained_embedding_file: Forwarded to the word embedding only;
            the context embedding is never given a pretrained file.
        dropout: Forwarded to both ``Embedding``s.
        init_type, low, high, mean, std, fan_mode: Initialisation
            settings, forwarded unchanged to both ``Embedding``s.
        model_mode: Forwarded to the word embedding only; the context
            embedding relies on ``Embedding``'s own default.
        region_embedding_type: Stored on the instance as
            ``self.region_embedding_type``.

    Raises:
        ValueError: If ``region_size`` is not a positive odd number.
    """
    super(RegionEmbeddingLayer, self).__init__()

    # Validate explicitly instead of with ``assert``: asserts are removed
    # under ``python -O``, which would let an even region size through and
    # defer the failure to an opaque shape mismatch much later.
    if region_size < 1 or region_size % 2 != 1:
        raise ValueError(
            "region_size must be a positive odd number, got %r"
            % (region_size,))

    self.region_embedding_type = region_embedding_type
    self.region_size = region_size
    # Number of positions on each side of the region centre. Integer
    # floor division avoids the round trip through float.
    self.radius = int(region_size) // 2
    self.embedding_dim = embedding_dim

    # Word table: one embedding_dim vector per vocabulary entry.
    self.embedding = Embedding(
        dict_map, embedding_dim, "RegionWord" + name, config=config,
        padding_idx=padding,
        pretrained_embedding_file=pretrained_embedding_file,
        dropout=dropout, init_type=init_type, low=low, high=high, mean=mean,
        std=std, fan_mode=fan_mode, model_mode=model_mode)

    # Context table: region_size stacked embedding_dim vectors per entry.
    # Deliberately mirrors the original call: no pretrained file and no
    # explicit model_mode are passed here.
    self.context_embedding = Embedding(
        dict_map, embedding_dim * region_size, "RegionContext" + name,
        config=config, padding_idx=padding, dropout=dropout,
        init_type=init_type, low=low, high=high, mean=mean, std=std,
        fan_mode=fan_mode)
| 168 | |
| 169 | def forward(self, vocab_ids): |
| 170 | seq_length = vocab_ids.size(1) |
| 171 | actual_length = vocab_ids.size(1) - self.radius * 2 |
| 172 | trim_vocab_id = vocab_ids[:, self.radius:seq_length - self.radius] |
| 173 | slice_vocabs = \ |
| 174 | [vocab_ids[:, i:i + self.region_size] for i in |
| 175 | range(actual_length)] |
| 176 | slice_vocabs = torch.cat(slice_vocabs, 1) |
| 177 | slice_vocabs = \ |
| 178 | slice_vocabs.view(-1, actual_length, self.region_size) |
| 179 | |
| 180 | if self.region_embedding_type == RegionEmbeddingType.WC: |
| 181 | vocab_embedding = self.embedding(slice_vocabs) |
| 182 | context_embedding = self.context_embedding(trim_vocab_id) |
| 183 | context_embedding = context_embedding.view( |
| 184 | -1, actual_length, self.region_size, self.embedding_dim) |
| 185 | region_embedding = vocab_embedding * context_embedding |
| 186 | region_embedding, _ = region_embedding.max(2) |
| 187 | elif self.region_embedding_type == RegionEmbeddingType.CW: |
| 188 | vocab_embedding = self.embedding(trim_vocab_id).unsqueeze(2) |
| 189 | context_embedding = self.context_embedding(slice_vocabs) |
| 190 | size = context_embedding.size() |
| 191 | context_embedding = context_embedding.view( |
| 192 | size[0], size[1], size[2], self.region_size, self.embedding_dim) |
| 193 | mask = torch.ones( |
| 194 | [self.region_size, self.region_size, self.embedding_dim]) |
| 195 | |
| 196 | for i in range(self.region_size): |
| 197 | mask[i][self.region_size - i - 1] = 0. |
| 198 | neg_mask = mask * -65500.0 |