(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None)
| 183 | |
| 184 | |
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, replacing each ``'<image>'`` placeholder with ``image_token_index``.

    The prompt is split on the literal ``'<image>'`` and each text piece is
    tokenized separately via ``tokenizer(chunk).input_ids``. The resulting
    id lists are joined with one ``image_token_index`` between consecutive
    pieces.

    If the first piece's ids begin with ``tokenizer.bos_token_id``, exactly
    one BOS is kept at the very start of the output. Any BOS that the
    tokenizer prepended to a piece is removed, so it does not appear again
    after each image token. A leading token is removed only if it really
    equals ``bos_token_id``; other tokens are never dropped.

    Args:
        prompt: Text that may contain zero or more ``'<image>'`` placeholders.
        tokenizer: Callable returning an object with an ``input_ids`` list,
            and exposing a ``bos_token_id`` attribute (which may be ``None``).
        image_token_index: Id written in place of each placeholder.
        return_tensors: ``None`` to return a Python list, or ``'pt'`` to return
            a ``torch.long`` tensor.

    Returns:
        A list of token ids, or a 1-D ``torch.long`` tensor when
        ``return_tensors == 'pt'``.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    # str.split always yields at least one piece, so prompt_chunks is never empty.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
    bos_id = tokenizer.bos_token_id

    input_ids = []
    # Decide once, from the first piece, whether the tokenizer is prepending BOS.
    strip_bos = bool(prompt_chunks[0]) and prompt_chunks[0][0] == bos_id
    if strip_bos:
        # Keep a single BOS at the very start of the sequence.
        input_ids.append(bos_id)

    for position, chunk_ids in enumerate(prompt_chunks):
        if position > 0:
            # One image token between consecutive text pieces.
            input_ids.append(image_token_index)
        # Drop a per-piece BOS only if it is actually present. Never blindly
        # drop the first token, which may be real text.
        if strip_bos and chunk_ids and chunk_ids[0] == bos_id:
            chunk_ids = chunk_ids[1:]
        input_ids.extend(chunk_ids)

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
| 205 | |
| 206 | |
| 207 | def get_model_name_from_path(model_path): |
no test coverage detected