| 124 | |
| 125 | |
| 126 | class VisionTextProcessor(object): |
@staticmethod
def get_default_config(updates=None):
    """Build the processor's default configuration, optionally overridden.

    Args:
        updates: Optional overrides. When given, they are wrapped in a
            ``ConfigDict``, passed through ``copy_and_resolve_references()``
            and merged over the defaults via ``update``.

    Returns:
        A ``ConfigDict`` holding the defaults listed below, with ``updates``
        applied on top when provided.
    """
    # Defaults are applied in this order; keep it stable.
    defaults = (
        ('fields_from_example', ''),
        ('subfield_separator', ' '),
        ('add_bos_token', True),
        ('add_eos_token', True),
        ('prepend_text', ''),
        ('fields_index', -1),
        ('eof_token', 8192),  # denotes end of each frame for video generation
        ('eov_token', 8193),  # denotes end of vision generation
        ('n_tokens_per_frame', 256),  # 16 x 16 VQ codes
        ('max_n_frames', -1),
    )
    cfg = ConfigDict()
    for key, value in defaults:
        setattr(cfg, key, value)

    if updates is None:
        return cfg

    overrides = ConfigDict(updates).copy_and_resolve_references()
    cfg.update(overrides)
    return cfg
| 143 | |
def __init__(self, config, tokenizer):
    """Set up the processor from config overrides and a tokenizer.

    Args:
        config: Overrides forwarded to ``get_default_config``. The resulting
            configuration must have a non-empty ``fields_from_example``.
        tokenizer: Object exposing ``encode``; it is stored on the instance
            and used here to encode the ``<vision>`` and ``</vision>``
            marker strings.

    Raises:
        AssertionError: If ``fields_from_example`` is the empty string.
            The check is an ``assert`` statement, so it does not run when
            Python is started with ``-O``.
    """
    resolved = self.get_default_config(config)
    self.config = resolved
    assert resolved.fields_from_example != '', (
        'fields_from_example must be specified.'
    )

    self.tokenizer = tokenizer
    # Encode the vision delimiter markers once, at construction time.
    for attr, marker in (
        ('vision_start', '<vision>'),
        ('vision_end', '</vision>'),
    ):
        setattr(self, attr, tokenizer.encode(marker))
| 152 | |
| 153 | def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=True): |
| 154 | if has_aux: |
| 155 | example, *aux = example |
| 156 | else: |
| 157 | aux = tuple() |
| 158 | rand_state = random.Random(aux[-1]) # makes augmentations deterministic by line number |
| 159 | token_buffer = [] |
| 160 | loss_mask_buffer = [] |
| 161 | vision_mask = [] |
| 162 | |
| 163 | fields = example[self.config.fields_from_example] |
| 164 | if isinstance(fields, (tuple, list)): |
| 165 | if self.config.fields_index >= 0: |
| 166 | fields = fields[self.config.fields_index] |
| 167 | else: |
| 168 | # seed based on line number |
| 169 | fields = rand_state.choice(fields) |
| 170 | fields = fields.split(',') |
| 171 | |
| 172 | if add_bos_token and self.config.add_bos_token: |
| 173 | token_buffer.append(self.tokenizer.bos_token_id) |
| 174 | loss_mask_buffer.append(0.0) |
| 175 | vision_mask.append(False) |
| 176 | |
| 177 | for i, field in enumerate(fields): |
| 178 | if field.startswith('[') and field.endswith(']'): |
| 179 | # No loss for this field. |
| 180 | field = field[1:-1] |
| 181 | mask = 0.0 |
| 182 | else: |
| 183 | mask = 1.0 |