(self, query, img_size=256)
| 152 | raise ValueError('') |
| 153 | |
def parse_query(self, query, img_size=256):
    """Convert a space-separated query string into a flat list of token ids.

    The query is split on single spaces and each part is handled as follows:

    * A key of ``self.command_tokens``: emits the mapped id, except
      ``[MASK]`` which emits the placeholder ``-1``.
    * ``[MASK]*N``: emits ``N`` placeholders (``-1``); ``N`` must be > 0.
    * ``[Image]path``, ``[ImageN]path`` or ``[Image*N]path``: the image at
      ``path`` is read and encoded through ``self.img_tokenizer``; the first
      ``N`` codes are kept (1024 when ``N`` is omitted) and every later
      position is replaced by ``-1``.
    * Anything else: plain text. Consecutive text parts are re-joined with
      single spaces and encoded with ``self.EncodeAsIds``.

    Ids are emitted in query order: pending text is always encoded before
    the ids of the special part that follows it.

    Args:
        query: The query string to parse.
        img_size: Forwarded to ``self.img_tokenizer.read_img`` as
            ``img_size``.

    Returns:
        A list of ids, with ``-1`` marking masked positions.

    Raises:
        ValueError: If a ``[MASK]*N`` or image count is not an integer, if
            ``N`` in ``[MASK]*N`` is not positive, or if an image part has
            no closing ``]``.
    """
    text_buffer = []
    ret = []

    def flush_text():
        # Encode the buffered words as ONE string so the text tokenizer sees
        # the same phrase that appeared in the query.
        if text_buffer:
            ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
            text_buffer.clear()

    for part in query.split(' '):
        if part in self.command_tokens:
            flush_text()
            ret.append(-1 if part == '[MASK]' else self.command_tokens[part])
        elif part.startswith('[MASK]*'):  # repeated mask: [MASK]*N
            count = int(part[7:])
            # Explicit check instead of `assert`, which is stripped by -O.
            if count <= 0:
                raise ValueError(
                    'mask count must be positive in %r' % part)
            flush_text()
            ret.extend([-1] * count)
        elif part.startswith('[Image'):  # [Image]path, [ImageN]path, [Image*N]path
            # partition() splits on the first ']' only, so a path that
            # itself contains ']' is kept intact.
            num_codes, closing, img_path = part[6:].partition(']')
            if not closing:
                raise ValueError(
                    "image token %r is missing the closing ']'" % part)
            if num_codes.startswith('*'):
                # Accept the documented "[Image*N]path" spelling as well as
                # the historical "[ImageN]path".
                num_codes = num_codes[1:]
            num_codes = int(num_codes) if num_codes else 1024

            # Text that precedes the image must precede its codes too.
            flush_text()
            raw_img = self.img_tokenizer.read_img(img_path, img_size=img_size)
            # Per the original note the codes are shaped [1, 32*32]
            # (not verifiable from this file alone).
            img_codes = self.img_tokenizer.EncodeAsIds(raw_img)
            img_codes[0, num_codes:] = -1  # mask everything after the kept prefix
            ret.extend(img_codes[0].tolist())
        else:
            text_buffer.append(part)

    flush_text()
    return ret
| 197 | |
| 198 | def get_tokenizer(args=None): |
| 199 | if not hasattr(get_tokenizer, 'tokenizer'): |
no test coverage detected