| 16 | import torch |
| 17 | # TODO - SHOULD BE FURTHER IMPROVED. |
| 18 | class UniversalPrompting(): |
| 19 | def __init__(self, text_tokenizer, |
| 20 | special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), |
| 21 | max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1): |
| 22 | """ |
| 23 | :param text_tokenizer: original text tokenizer |
| 24 | """ |
| 25 | self.text_tokenizer = text_tokenizer |
| 26 | self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) |
| 27 | self.text_tokenizer.add_tokens(list(special_tokens)) |
| 28 | self.sptids_dict = {token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token])) for token in |
| 29 | special_tokens} |
| 30 | self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id]) |
| 31 | self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id]) |
| 32 | self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id]) |
| 33 | # plus 1 because at this time we add a task token before |
| 34 | self.max_text_len = max_text_len + 1 |
| 35 | self.pad_id = self.text_tokenizer.convert_tokens_to_ids('[PAD]') |
| 36 | self.ignore_id = ignore_id |
| 37 | self.cond_dropout_prob = cond_dropout_prob |
| 38 | |
| 39 | def t2i_prompt(self, text_ids, image_ids, labels): |
| 40 | |
| 41 | device = image_ids.device |
| 42 | sequence_ids = [] |
| 43 | attention_masks = [] |
| 44 | label_ids = [] |
| 45 | probs = torch.rand(len(text_ids)) |
| 46 | for i in range(len(text_ids)): |
| 47 | |
| 48 | if len(text_ids[i]) == 0: |
| 49 | text_ids[i] = [self.text_tokenizer.bos_token_id] |
| 50 | elif text_ids[i][0] != self.text_tokenizer.bos_token_id: |
| 51 | text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i] |
| 52 | |
| 53 | temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id] |
| 54 | |
| 55 | # randomly dropout text condition |
| 56 | if probs[i] < self.cond_dropout_prob: |
| 57 | temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id] |
| 58 | |
| 59 | if self.max_text_len >= len(temp_ids): |
| 60 | temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids |
| 61 | temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3) |
| 62 | else: |
| 63 | # should add the eos token |
| 64 | temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id] |
| 65 | temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens |
| 66 | |
| 67 | # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi] |
| 68 | temp_label_ids = torch.cat([ |
| 69 | # should we predict text tokens when doing image reconstruction? |
| 70 | torch.tensor(temp_ids).to(device), |
| 71 | self.sptids_dict['<|soi|>'].to(device), |
| 72 | labels[i], |
| 73 | self.sptids_dict['<|eoi|>'].to(device) |
| 74 | ], dim=0) |
| 75 |
no outgoing calls
no test coverage detected