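Use this vocab class when you have a custom vocabulary class but you want to use pretrained embedding vectors for it. This will only load the vectors which intersect with your vocabulary. Use the embedding_name specified in torchtext's pretrained aliases: ['charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d', 'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d', 'glove.twitter.27B.50d', 'glove.twitter.27B.100d', 'glove.twitter.27B.200d', 'glove.6B.50d', 'glove.6B.100d', 'glove.6B.200d', 'glove.6B.300d']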
(self, vocab_file, embedding_name, *args, **kwargs)
| 248 | |
| 249 | class IntersectedVocab(BaseVocab): |
| 250 | def __init__(self, vocab_file, embedding_name, *args, **kwargs): |
| 251 | """Use this vocab class when you have a custom vocabulary class but you |
| 252 | want to use pretrained embedding vectos for it. This will only load |
| 253 | the vectors which intersect with your vocabulary. Use the |
| 254 | embedding_name specified in torchtext's pretrained aliases: |
| 255 | ['charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d', |
| 256 | 'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d', |
| 257 | 'glove.twitter.27B.50d', 'glove.twitter.27B.100d', |
| 258 | 'glove.twitter.27B.200d', 'glove.6B.50d', 'glove.6B.100d', |
| 259 | 'glove.6B.200d', 'glove.6B.300d'] |
| 260 | |
| 261 | Parameters |
| 262 | ---------- |
| 263 | vocab_file : str |
| 264 | Vocabulary file containing list of words with one word per line |
| 265 | which will be used to collect vectors |
| 266 | embedding_name : str |
| 267 | Embedding name picked up from the list of the pretrained aliases |
| 268 | mentioned above |
| 269 | """ |
| 270 | super(IntersectedVocab, self).__init__(vocab_file, *args, **kwargs) |
| 271 | |
| 272 | self.type = "intersected" |
| 273 | |
| 274 | name = embedding_name.split(".")[0] |
| 275 | dim = embedding_name.split(".")[2][:-1] |
| 276 | middle = embedding_name.split(".")[1] |
| 277 | |
| 278 | class_name = EMBEDDING_NAME_CLASS_MAPPING[name] |
| 279 | |
| 280 | if not hasattr(vocab, class_name): |
| 281 | from pythia.common.registry import registry |
| 282 | |
| 283 | writer = registry.get("writer") |
| 284 | error = "Unknown embedding type: %s" % name, "error" |
| 285 | if writer is not None: |
| 286 | writer.write(error, "error") |
| 287 | raise RuntimeError(error) |
| 288 | |
| 289 | params = [middle] |
| 290 | |
| 291 | if name == "glove": |
| 292 | params.append(int(dim)) |
| 293 | |
| 294 | vector_cache = os.path.join(get_pythia_root(), ".vector_cache") |
| 295 | embedding = getattr(vocab, class_name)(*params, cache=vector_cache) |
| 296 | |
| 297 | self.vectors = torch.empty( |
| 298 | (self.get_size(), len(embedding.vectors[0])), dtype=torch.float |
| 299 | ) |
| 300 | |
| 301 | self.embedding_dim = len(embedding.vectors[0]) |
| 302 | |
| 303 | for i in range(0, 4): |
| 304 | self.vectors[i] = torch.ones_like(self.vectors[i]) * 0.1 * i |
| 305 | |
| 306 | for i in range(4, self.get_size()): |
| 307 | word = self.itos[i] |
nothing calls this directly
no test coverage detected