| 150 | setattr(self, module_name, module) |
| 151 | |
def embedding(self,
              input_ids,
              position_ids=None,
              token_type_ids=None,
              prompt_embedding_table=None,
              prompt_tasks=None,
              prompt_vocab_size=None):
    """Compute the input embedding from token ids and optional extra ids.

    The word embedding comes from ``self.vocab_embedding`` and is scaled by
    ``self.embedding_scale``. A position embedding and a token-type embedding
    are then added when the corresponding modules are set. Finally,
    ``self.ln_embed`` is applied when it is set.

    Args:
        input_ids: Ids passed to ``self.vocab_embedding``.
        position_ids: Passed unchanged to ``self.position_embedding``. It is
            only used when that module is set. NOTE(review): it is forwarded
            even if None, so callers presumably must supply it whenever
            position embedding is enabled.
        token_type_ids: Passed unchanged to ``self.token_type_embedding``. It
            is only used when that module is set (same caveat as above).
        prompt_embedding_table: When not None, this value, ``prompt_tasks``
            and ``prompt_vocab_size`` are forwarded, in that order, as extra
            positional arguments to ``self.vocab_embedding``.
        prompt_tasks: See ``prompt_embedding_table``.
        prompt_vocab_size: See ``prompt_embedding_table``.

    Returns:
        The combined embedding produced by the enabled sub-modules.

    Side effects:
        Registers the scaled word embedding as ``'word_embeddings'``. When
        the position module is set, also registers its output as
        ``'position_embeddings'``.
    """
    # position_ids / token_type_ids come from the caller and are used as-is;
    # this method never derives them itself.
    prompt_args = ()
    if prompt_embedding_table is not None:
        # Prompt-related arguments are forwarded only when a table is given.
        prompt_args = (prompt_embedding_table, prompt_tasks,
                       prompt_vocab_size)

    hidden = self.vocab_embedding(input_ids,
                                  *prompt_args) * self.embedding_scale
    self.register_network_output('word_embeddings', hidden)

    if self.position_embedding:
        position_out = self.position_embedding(position_ids)
        self.register_network_output('position_embeddings', position_out)
        hidden = hidden + position_out

    if self.token_type_embedding:
        hidden = hidden + self.token_type_embedding(token_type_ids)

    # The optional embedding LayerNorm is the last step.
    return self.ln_embed(hidden) if self.ln_embed else hidden
| 179 | |
| 180 | |
| 181 | class EncoderLayer(Module): |