| 21 | from .vqvae_tokenizer import VQVAETokenizer, sqrt_int |
| 22 | |
| 23 | class UnifiedTokenizer(object): |
def __init__(self, img_tokenizer_path, device, img_tokenizer_num_tokens=None):
    """Set up the image tokenizer, the text tokenizer and the command-token table.

    Args:
        img_tokenizer_path: forwarded to ``VQVAETokenizer`` as ``model_path``.
            May be None when only the image vocabulary size is known.
        device: stored on the instance and forwarded to ``VQVAETokenizer``.
        img_tokenizer_num_tokens: image vocabulary size; only consulted when
            ``img_tokenizer_path`` is None, in which case a ``FakeTokenizer``
            of that size stands in for the real image tokenizer.
    """
    self.device = device

    # A real VQVAE is loaded unless the path is missing AND a vocab size was
    # supplied (pretraining where only the VQVAE vocab size is known, since
    # the VQVAE itself is developing fast).
    if img_tokenizer_path is not None or img_tokenizer_num_tokens is None:
        self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device)
    else:
        self.img_tokenizer = FakeTokenizer(img_tokenizer_num_tokens)
    self.txt_tokenizer = from_pretrained()

    # Image ids and text ids together form the base vocabulary.
    self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens

    # Command token names, in id order; each name's position in this tuple is
    # its offset relative to the end of the base vocabulary.
    command_names = (
        '[PAD]',
        '[BOI1]',  # Begin
        '[BOI2]',
        '[BOI3]',
        '[EOI1]',  # End
        '[EOI2]',
        '[EOI3]',
        '[ROI1]',  # Reference
        '[ROI2]',
        '[ROI3]',
        '[SEP]',
        '[MASK]',
        '[CLS]',
        '[ENC]',
        '[TINY]',  # 8 * 8
        '[SMALL]',  # 16 * 16
        '[BASE]',  # 32 * 32
        '[BIG]',  # 64 * 64
        '[POS0]',  # 58210
        '[POS1]',
        '[POS2]',
        '[POS3]',
        '[POS4]',
        '[POS5]',
        '[POS6]',
        '[POS7]',
        '[POS8]',
        # Please leave the ``size tokens'' at the back of command tokens
    )
    self.raw_command_tokens = [
        (name, offset) for offset, name in enumerate(command_names)
    ]

    # Shift every relative offset past the base vocabulary to get absolute ids.
    base_vocab_size = self.num_tokens
    absolute_ids = {}
    for name, offset in self.raw_command_tokens:
        absolute_ids[name] = offset + base_vocab_size
    self.command_tokens = absolute_ids

    # The total vocabulary now also counts the command tokens.
    self.num_tokens += len(command_names)
| 68 | |
| 69 | def __getitem__(self, command_token): |
| 70 | return self.command_tokens[command_token] |
| 71 | |
| 72 | def __len__(self): |
| 73 | """total number of tokens""" |
| 74 | return self.num_tokens |
| 75 | |
| 76 | def __call__(self, inputs, process_fn=None): |
| 77 | """run preprocessing and encode inputs as Ids |
| 78 | CANNOT contain command tokens""" |
| 79 | if isinstance(inputs, torch.Tensor): # image |
| 80 | if len(inputs.shape) == 3: |