| 299 | |
| 300 | |
class InstanceDataRootDataset(Dataset):
    """Dataset of images read from one directory, captioned by their file names.

    Every regular file directly inside ``instance_data_root`` is treated as an
    image. Its prompt is the file name without its final extension, e.g.
    ``"a photo of a dog.png"`` -> ``"a photo of a dog"``.

    Args:
        instance_data_root: Directory containing the training images.
        tokenizer: Tokenizer forwarded to ``tokenize_prompt``.
        size: Image size forwarded to ``process_image``.
    """

    def __init__(
        self,
        instance_data_root,
        tokenizer,
        size=512,
    ):
        self.size = size
        self.tokenizer = tokenizer
        # ``iterdir`` also yields sub-directories, which ``Image.open`` cannot
        # read, so keep regular files only. Sorting makes the index -> file
        # mapping deterministic; raw ``iterdir`` order is filesystem-dependent.
        # NOTE(review): non-image files (e.g. ``.DS_Store``) are still included.
        self.instance_images_path = sorted(
            path for path in Path(instance_data_root).iterdir() if path.is_file()
        )

    def __len__(self):
        return len(self.instance_images_path)

    def __getitem__(self, index):
        """Return ``process_image``'s output for one file plus its prompt token ids.

        ``index`` wraps modulo the number of files, so indices past the end are
        accepted. With no files, the modulo raises ``ZeroDivisionError``.
        """
        image_path = self.instance_images_path[index % len(self.instance_images_path)]
        # ``Image.open`` keeps the file handle open; the context manager closes
        # it once processing is done.
        # NOTE(review): assumes process_image fully reads the pixel data rather
        # than returning the lazily-loaded PIL image itself — confirm.
        with Image.open(image_path) as instance_image:
            rv = process_image(instance_image, self.size)

        # The caption is the file name minus its final extension.
        prompt = Path(image_path).stem
        # presumably [0] drops a leading batch dimension from tokenize_prompt's
        # output — verify against tokenize_prompt.
        rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]
        return rv
| 323 | |
| 324 | |
| 325 | class InstanceDataImageDataset(Dataset): |
no outgoing calls
no test coverage detected
searching dependent graphs…