| 8 | |
| 9 | |
class ImageDataset(Dataset):
    """Image dataset reading two domains from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    Every file whose name matches ``*.*`` in each directory is used, in sorted
    path order. Each item is a dict ``{"A": ..., "B": ...}`` holding one image
    from each domain after the transform pipeline has been applied.

    Args:
        root: Dataset root directory.
        transforms_: Sequence of transforms passed to ``transforms.Compose``.
            ``None`` (the default) means no transforms: the opened image is
            returned unchanged.
        unaligned: If true, the B image is drawn at random for every item
            instead of being paired with A by index.
        mode: Sub-directory of ``root`` holding the split, e.g. ``"train"``.

    Raises:
        FileNotFoundError: If either domain directory has no matching files.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Compose(None) constructs fine but raises TypeError on first use, so a
        # missing pipeline is mapped to an empty (identity) one instead.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.unaligned = unaligned

        # Fail fast: an empty listing (typically a wrong root/mode) would
        # otherwise surface later as ZeroDivisionError in ``index % len(...)``.
        self.files_A = self._list_images(root, mode, "A")
        self.files_B = self._list_images(root, mode, "B")

    @staticmethod
    def _list_images(root, mode, domain):
        """Return the sorted paths matching ``<root>/<mode>/<domain>/*.*``.

        Raises:
            FileNotFoundError: If nothing matches (missing or empty directory).
        """
        pattern = os.path.join(root, mode, domain, "*.*")
        files = sorted(glob.glob(pattern))
        if not files:
            raise FileNotFoundError("No image files found matching %r" % pattern)
        return files

    def _load(self, path):
        """Open the image at ``path`` and run it through ``self.transform``."""
        # NOTE(review): the image is handed to the transform in whatever mode it
        # is stored in (no ``convert("RGB")``); confirm both folders are uniform.
        return self.transform(Image.open(path))

    def __getitem__(self, index):
        # Indices past the end of the shorter domain wrap around (see __len__).
        item_A = self._load(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # Independent random B per item, so A/B pairs are not fixed.
            path_B = random.choice(self.files_B)
        else:
            path_B = self.files_B[index % len(self.files_B)]

        return {"A": item_A, "B": self._load(path_B)}

    def __len__(self):
        """Return the size of the larger domain; the smaller one wraps around."""
        return max(len(self.files_A), len(self.files_B))