| 6 | |
| 7 | |
class TextImageDataset(torch.utils.data.Dataset):
    """Dataset of (caption, image) pairs listed in ``train/metadata.csv``.

    The CSV at ``<dataset_path>/train/metadata.csv`` must provide a
    ``file_name`` column (image paths relative to ``<dataset_path>/train``)
    and a ``text`` column (the caption for each image).

    The reported length is ``steps_per_epoch``, not the number of rows in the
    CSV: every ``__getitem__`` call picks a sample at random, so one "epoch"
    is simply ``steps_per_epoch`` random draws.

    Args:
        dataset_path: Root directory that contains the ``train`` folder.
        steps_per_epoch: Value returned by ``len(dataset)``.
        height: Output image height in pixels.
        width: Output image width in pixels.
        center_crop: Crop the centre of the resized image when true,
            otherwise crop at a random location.
        random_flip: Apply ``transforms.RandomHorizontalFlip()`` when true.
    """

    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        super().__init__()
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        self.height = height
        self.width = width

        crop_size = (height, width)
        crop = transforms.CenterCrop(crop_size) if center_crop else transforms.RandomCrop(crop_size)
        pipeline = [crop]
        # Only add the flip when requested. An identity ``Lambda(lambda x: x)``
        # placeholder would make the dataset unpicklable, which breaks
        # DataLoader workers started with the "spawn" method.
        if random_flip:
            pipeline.append(transforms.RandomHorizontalFlip())
        pipeline.append(transforms.ToTensor())
        # With ToTensor output in [0, 1], this rescales values to [-1, 1].
        pipeline.append(transforms.Normalize([0.5], [0.5]))
        self.image_processor = transforms.Compose(pipeline)

    def __getitem__(self, index):
        """Return a randomly selected sample.

        ``index`` is only used as an offset added to a random draw, so the
        sample returned for a given index depends on the torch RNG state.

        Returns:
            dict with ``"text"`` (the caption) and ``"image"`` (the output of
            the crop / flip / tensor / normalize pipeline).
        """
        num_samples = len(self.path)
        random_offset = torch.randint(0, num_samples, (1,))[0]
        # For fixed seed. Converted to a plain int for list indexing.
        data_id = int((random_offset + index) % num_samples)
        text = self.text[data_id]

        # ``convert`` returns a fully loaded copy, so the file can be closed
        # as soon as the ``with`` block exits.
        with Image.open(self.path[data_id]) as source_image:
            image = source_image.convert("RGB")

        # Scale so that both sides are at least the target size while keeping
        # the aspect ratio; the crop in ``image_processor`` trims the excess.
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        shape = [round(height * scale), round(width * scale)]
        image = torchvision.transforms.functional.resize(
            image,
            shape,
            interpolation=transforms.InterpolationMode.BILINEAR,
        )
        image = self.image_processor(image)
        return {"text": text, "image": image}

    def __len__(self):
        """Return the configured number of steps per epoch."""
        return self.steps_per_epoch
no outgoing calls
no test coverage detected