| 27 | raise NotImplementedError("Prompt function " + self.spell_func) |
| 28 | |
def init_embedding(self, word_embeddings=None, task_tokens=None, num_words=5000):
    """Initialise the prompt ("spell") embedding rows from existing word embeddings.

    Each of the first ``self.spell_length`` rows of ``self.spell_embeddings.weight``
    is overwritten in place. The new value is the embedding of a token id drawn
    uniformly from ``range(min(num_words, len(word_embeddings)))``. When
    ``task_tokens`` is a non-empty sequence, that embedding is linearly
    interpolated with the embedding of a randomly chosen task token. A fresh
    uniform ratio in [0, 1) is used for each row.

    Args:
        word_embeddings: Indexable table of word vectors, indexed by token id.
            Presumably this is the backbone's input-embedding weight tensor
            (TODO confirm against callers). It is required; the ``None``
            default exists only so the original signature is preserved.
        task_tokens: Optional sequence of token ids to blend in. ``None`` or an
            empty sequence means plain random-token initialisation.
        num_words: Exclusive upper bound on the sampled token ids. It is
            clamped to the number of rows in ``word_embeddings``. Defaults to
            5000, the value that used to be hard-coded.

    Raises:
        ValueError: If ``word_embeddings`` is ``None``, or if there are no
            token ids to sample from (``num_words`` < 1 or an empty table).
    """
    if word_embeddings is None:
        raise ValueError("init_embedding requires word_embeddings")
    # Clamp so small vocabularies cannot raise IndexError. The previous
    # hard-coded 5000 did so whenever the table had fewer rows.
    vocab_size = min(num_words, len(word_embeddings))
    if vocab_size < 1:
        raise ValueError("init_embedding: no token ids available to sample from")
    # Explicit length check instead of truthiness, so tensor-valued
    # task_tokens keep working. An empty sequence would crash random.choice.
    use_task_tokens = task_tokens is not None and len(task_tokens) > 0
    with torch.no_grad():
        for i in range(self.spell_length):
            # The RNG call order (randrange, choice, random) matches the
            # original, so seeded runs reproduce the same initialisation.
            rand_token = random.randrange(vocab_size)
            target_embedding = word_embeddings[rand_token]
            if use_task_tokens:
                task_embedding = word_embeddings[random.choice(task_tokens)]
                ratio = random.random()
                target_embedding = target_embedding * ratio + task_embedding * (1 - ratio)
            self.spell_embeddings.weight.data[i] = target_embedding
| 43 | |
| 44 | def forward(self): |
| 45 | prompt_embeds = self.spell_embeddings.weight.unsqueeze(0) |