| 99 | return self.dropout(embedding) |
| 100 | |
def load_pretrained_embedding(
        self, embedding_lookup_table, dict_map, embedding_dim, name,
        pretrained_embedding_file):
    """Overwrite rows of ``embedding_lookup_table`` with vectors read from a file.

    Each data line of ``pretrained_embedding_file`` is ``token v1 v2 ...``,
    separated by single spaces. A line with exactly two fields is treated as a
    header, presumably word2vec-style ``<count> <dim>``. Its second field must
    equal ``embedding_dim``. Tokens missing from ``dict_map`` are ignored.
    Rows whose token never appears in the file keep their existing values.

    Args:
        embedding_lookup_table: Row-assignable table. Row ``dict_map[token]``
            is overwritten (presumably a 2-D float tensor of shape
            ``(len(dict_map), embedding_dim)`` -- TODO confirm with caller).
        dict_map: Mapping from token string to row index.
        embedding_dim: Expected number of values per vector.
        name: Label used only in log messages.
        pretrained_embedding_file: Path of the embedding text file.

    Raises:
        AssertionError: If a two-field header declares a different dimension.
        ValueError: If a vector for a known token does not have exactly
            ``embedding_dim`` values.
    """
    self.logger.warn(
        "Load %s embedding from %s" % (name, pretrained_embedding_file))
    # Collect distinct rows instead of counting hits. A token repeated in the
    # file must not inflate the pretrained count, which would otherwise make
    # the "randomly initialize" count wrong or even negative.
    loaded_indices = set()
    # NOTE(review): open() uses the locale's default encoding; confirm it
    # matches the encoding used when dict_map was built.
    with open(pretrained_embedding_file) as fin:
        for line in fin:
            stripped = line.strip()
            if not stripped:
                # Blank lines carry no vector.
                continue
            data = stripped.split(' ')
            # Check embedding info (two-field header line).
            if len(data) == 2:
                assert int(data[1]) == embedding_dim, \
                    "Pretrained embedding dim not matching: %s, %d" % (
                        data[1], embedding_dim)
                continue
            token = data[0]
            if token not in dict_map:
                continue
            values = [float(i) for i in data[1:]]
            # Without this check, a 1-value vector would silently broadcast
            # over the whole row. Any other length would fail with an opaque
            # tensor error.
            if len(values) != embedding_dim:
                raise ValueError(
                    "Pretrained %s embedding for %r has %d values, "
                    "expected %d" % (name, token, len(values), embedding_dim))
            index = dict_map[token]
            embedding_lookup_table[index] = torch.FloatTensor(values)
            loaded_indices.add(index)
    num_pretrained = len(loaded_indices)
    self.logger.warn(
        "Total dict size of %s is %d" % (name, len(dict_map)))
    self.logger.warn("Size of pretrained %s embedding is %d" % (
        name, num_pretrained))
    self.logger.warn(
        "Size of randomly initialize %s embedding is %d" % (
            name, len(dict_map) - num_pretrained))
| 128 | |
| 129 | |
| 130 | class RegionEmbeddingType(Type): |