| 60 | |
| 61 | |
| 62 | class Embedding(torch.nn.Module): |
def __init__(self, dict_map, embedding_dim, name, config, padding_idx=None,
             pretrained_embedding_file=None, mode=EmbeddingProcessType.FLAT,
             dropout=0, init_type=InitType.XAVIER_UNIFORM, low=0, high=1,
             mean=0, std=1, activation_type=ActivationType.NONE,
             fan_mode=FAN_MODE.FAN_IN, negative_slope=0,
             model_mode=ModeType.TRAIN):
    """Build the lookup module and fill its weight table.

    Args:
        dict_map: Token -> row-index mapping. Its length is the vocabulary
            size, and it is used to place pretrained vectors.
        embedding_dim: Width of each embedding vector.
        name: Label used in the log message when pretrained vectors load.
        config: Passed to ``Logger``.
        padding_idx: If not None, this row of the weight table is zeroed.
            In FLAT mode it is also given to ``torch.nn.Embedding``.
        pretrained_embedding_file: Optional path. It is only read when
            ``model_mode == ModeType.TRAIN`` and the value is neither None
            nor the empty string.
        mode: ``EmbeddingProcessType.FLAT`` selects ``torch.nn.Embedding``.
            Any other value selects ``torch.nn.EmbeddingBag``, with this
            value passed through as its ``mode``.
        dropout: Probability for the ``torch.nn.Dropout`` used in forward.
        init_type, low, high, mean, std, activation_type, fan_mode,
        negative_slope: Forwarded unchanged to ``init_tensor``.
        model_mode: Gates pretrained loading (see above).
    """
    super().__init__()
    self.logger = Logger(config)
    self.dropout = torch.nn.Dropout(p=dropout)
    self.mode = mode

    vocab_size = len(dict_map)

    # Construct the module before calling init_tensor. Both may draw from
    # the global RNG, so keeping this order keeps initialisation
    # reproducible.
    if mode == EmbeddingProcessType.FLAT:
        lookup = torch.nn.Embedding(
            vocab_size, embedding_dim, padding_idx=padding_idx)
    else:
        lookup = torch.nn.EmbeddingBag(vocab_size, embedding_dim, mode=mode)
    self.embedding = lookup

    weights = init_tensor(
        tensor=torch.empty(vocab_size, embedding_dim),
        init_type=init_type, low=low, high=high, mean=mean, std=std,
        activation_type=activation_type, fan_mode=fan_mode,
        negative_slope=negative_slope)

    # Pretrained vectors are only read in training mode.
    wants_pretrained = (
        model_mode == ModeType.TRAIN
        and pretrained_embedding_file is not None
        and pretrained_embedding_file != "")
    if wants_pretrained:
        self.load_pretrained_embedding(
            weights, dict_map, embedding_dim, name,
            pretrained_embedding_file)

    # Zero the padding row last, so it overrides both random init and any
    # pretrained vector written to that index.
    if padding_idx is not None:
        weights[padding_idx] = 0.0
    self.embedding.weight.data.copy_(weights)
| 93 | |
def forward(self, vocab_ids, offset=None):
    """Look up embeddings for ``vocab_ids`` and apply dropout.

    Args:
        vocab_ids: Index input passed to the underlying lookup module.
        offset: Used only when ``self.mode`` is not FLAT. In that case it
            is passed as the second positional argument of
            ``torch.nn.EmbeddingBag`` (its ``offsets``). It is ignored in
            FLAT mode.

    Returns:
        The looked-up (FLAT) or bag-reduced (non-FLAT) embeddings, after
        ``self.dropout``.
    """
    if self.mode == EmbeddingProcessType.FLAT:
        return self.dropout(self.embedding(vocab_ids))
    return self.dropout(self.embedding(vocab_ids, offset))
| 100 | |
| 101 | def load_pretrained_embedding( |
| 102 | self, embedding_lookup_table, dict_map, embedding_dim, name, |
| 103 | pretrained_embedding_file): |
| 104 | self.logger.warn( |
| 105 | "Load %s embedding from %s" % (name, pretrained_embedding_file)) |
| 106 | with open(pretrained_embedding_file) as fin: |
| 107 | num_pretrained = 0 |
| 108 | for line in fin: |
| 109 | data = line.strip().split(' ') |
| 110 | # Check embedding info |
| 111 | if len(data) == 2: |
| 112 | assert int(data[1]) == embedding_dim, \ |
| 113 | "Pretrained embedding dim not matching: %s, %d" % ( |
| 114 | data[1], embedding_dim) |
| 115 | continue |
| 116 | if data[0] not in dict_map: |
| 117 | continue |
| 118 | embedding = torch.FloatTensor([float(i) for i in data[1:]]) |
| 119 | embedding_lookup_table[dict_map[data[0]]] = embedding |