| 29 | |
| 30 | # Custom dataset class |
class CustomDataset(Dataset):
    """Map-style dataset yielding one tokenized prompt and processed image per question.

    Each entry of ``questions`` is indexed with the keys ``"image"`` (a file
    name relative to ``image_folder``) and ``"text"`` (the question string).

    Args:
        questions: Indexable, sized collection of question records.
        image_folder: Directory that the ``"image"`` file names are joined onto.
        tokenizer: Tokenizer handed to ``tokenizer_image_token``.
        image_processor: Processor handed to ``process_images``.
        model_config: Model configuration; ``mm_use_im_start_end`` is read
            from it and the object is forwarded to ``process_images``.
        conv_mode: Optional key into ``conv_templates``. When omitted, the
            module-level ``args.conv_mode`` is consulted on every
            ``__getitem__`` call, as before. Passing it explicitly removes
            the dependency on that global, which is not guaranteed to exist
            in every process that ends up calling ``__getitem__``.
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config,
                 conv_mode=None):
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        self.conv_mode = conv_mode

    def __getitem__(self, index):
        """Return ``(input_ids, image_tensor, image_size)`` for question ``index``.

        ``image_size`` is the ``size`` attribute of the image after conversion
        to RGB.
        """
        line = self.questions[index]
        image_file = line["image"]
        qs = line["text"]

        # Prefix the question with the image placeholder, optionally wrapped
        # in start/end markers when the model configuration asks for them.
        if self.model_config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        # Prefer the explicitly injected template key. getattr() keeps
        # instances created without __init__ working; they fall back to the
        # module-level ``args`` exactly as the original code did.
        conv_mode = getattr(self, "conv_mode", None)
        if conv_mode is None:
            conv_mode = args.conv_mode

        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        # The second role's message is left as None before the prompt is rendered.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Close the underlying file as soon as the RGB copy has been made
        # instead of leaving the handle to the garbage collector.
        with Image.open(os.path.join(self.image_folder, image_file)) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = process_images([image], self.image_processor, self.model_config)[0]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return input_ids, image_tensor, image.size

    def __len__(self):
        """Return the number of question records."""
        return len(self.questions)
| 62 | |
| 63 | |
| 64 | def collate_fn(batch): |
no outgoing calls
no test coverage detected