| 342 | |
| 343 | |
class HuggingFaceDataset(Dataset):
    """Map-style dataset that adapts an indexable collection of examples.

    Each example is fetched from ``hf_dataset`` by index. Its image field goes
    through ``process_image``. The result is extended with a
    ``"prompt_input_ids"`` entry that holds the tokenized prompt.

    Args:
        hf_dataset: Sized, indexable collection of examples (presumably a
            Hugging Face ``datasets.Dataset`` -- confirm against callers).
            Each example must support lookup by ``image_key`` and
            ``prompt_key``.
        tokenizer: Passed unchanged to ``tokenize_prompt``.
        image_key: Key of the image field in each example.
        prompt_key: Key of the prompt field in each example.
        prompt_prefix: Optional value prepended to every prompt via ``+``.
            ``None`` disables prefixing.
        size: Passed unchanged to ``process_image`` (default 512).
    """

    def __init__(
        self,
        hf_dataset,
        tokenizer,
        image_key,
        prompt_key,
        prompt_prefix=None,
        size=512,
    ):
        # Where the data lives and which fields to read from each example.
        self.hf_dataset = hf_dataset
        self.image_key = image_key
        self.prompt_key = prompt_key
        # Per-example processing configuration.
        self.tokenizer = tokenizer
        self.prompt_prefix = prompt_prefix
        self.size = size

    def __len__(self):
        """Return the number of examples in the wrapped collection."""
        return len(self.hf_dataset)

    def __getitem__(self, index):
        """Build the processed example at ``index``.

        Returns:
            The object returned by ``process_image``, with a
            ``"prompt_input_ids"`` item added. It is assumed to be a mutable
            mapping, since it is assigned into by key.
        """
        record = self.hf_dataset[index]

        # The image is processed before the prompt field is read. The
        # mapping it returns is extended in place below.
        example = process_image(record[self.image_key], self.size)

        text = record[self.prompt_key]
        text = text if self.prompt_prefix is None else self.prompt_prefix + text

        # Keep only element 0 of the tokenizer output. This presumably drops
        # a leading batch dimension -- confirm against tokenize_prompt.
        token_batch = tokenize_prompt(self.tokenizer, text)
        example["prompt_input_ids"] = token_batch[0]
        return example
| 377 | |
| 378 | |
| 379 | def process_image(image, size): |
no outgoing calls
no test coverage detected
searching dependent graphs…