MCPcopy
hub / github.com/facebookresearch/mmf / __init__

Method __init__

pythia/utils/vocab.py:250–313  ·  view source on GitHub ↗

Use this vocab class when you have a custom vocabulary class but you want to use pretrained embedding vectors for it. This will only load the vectors which intersect with your vocabulary. Use the embedding_name specified in torchtext's pretrained aliases: ['charngram.100d', …]

(self, vocab_file, embedding_name, *args, **kwargs)

Source from the content-addressed store, hash-verified

248
249class IntersectedVocab(BaseVocab):
250 def __init__(self, vocab_file, embedding_name, *args, **kwargs):
251 """Use this vocab class when you have a custom vocabulary class but you
252 want to use pretrained embedding vectos for it. This will only load
253 the vectors which intersect with your vocabulary. Use the
254 embedding_name specified in torchtext's pretrained aliases:
255 ['charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d',
256 'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d',
257 'glove.twitter.27B.50d', 'glove.twitter.27B.100d',
258 'glove.twitter.27B.200d', 'glove.6B.50d', 'glove.6B.100d',
259 'glove.6B.200d', 'glove.6B.300d']
260
261 Parameters
262 ----------
263 vocab_file : str
264 Vocabulary file containing list of words with one word per line
265 which will be used to collect vectors
266 embedding_name : str
267 Embedding name picked up from the list of the pretrained aliases
268 mentioned above
269 """
270 super(IntersectedVocab, self).__init__(vocab_file, *args, **kwargs)
271
272 self.type = "intersected"
273
274 name = embedding_name.split(".")[0]
275 dim = embedding_name.split(".")[2][:-1]
276 middle = embedding_name.split(".")[1]
277
278 class_name = EMBEDDING_NAME_CLASS_MAPPING[name]
279
280 if not hasattr(vocab, class_name):
281 from pythia.common.registry import registry
282
283 writer = registry.get("writer")
284 error = "Unknown embedding type: %s" % name, "error"
285 if writer is not None:
286 writer.write(error, "error")
287 raise RuntimeError(error)
288
289 params = [middle]
290
291 if name == "glove":
292 params.append(int(dim))
293
294 vector_cache = os.path.join(get_pythia_root(), ".vector_cache")
295 embedding = getattr(vocab, class_name)(*params, cache=vector_cache)
296
297 self.vectors = torch.empty(
298 (self.get_size(), len(embedding.vectors[0])), dtype=torch.float
299 )
300
301 self.embedding_dim = len(embedding.vectors[0])
302
303 for i in range(0, 4):
304 self.vectors[i] = torch.ones_like(self.vectors[i]) * 0.1 * i
305
306 for i in range(4, self.get_size()):
307 word = self.itos[i]

Callers

nothing calls this directly

Calls 5

get_pythia_rootFunction · 0.90
getMethod · 0.80
writeMethod · 0.80
__init__Method · 0.45
get_sizeMethod · 0.45

Tested by

no test coverage detected