| 164 | return self.dataset[idx] |
| 165 | |
class TinyImageNetDataset:
    """Tiny ImageNet wrapped as an indexable dataset backed by ``ImageFolder``.

    The archive is optionally downloaded and extracted under ``root``, and the
    requested split is then loaded with ``datasets.ImageFolder``.

    Args:
        root: Directory that contains (or will receive) ``tiny-imagenet-200``.
        train: Load the ``train`` split when true, otherwise ``val``.
        transform: Transform applied to every image. When ``None`` a default
            pipeline is built: random 64-pixel crop (padding 8) plus horizontal
            flip for training, and tensor conversion plus normalization for
            both splits.
        download: Fetch and extract the archive when the data directory does
            not exist yet.

    Raises:
        FileNotFoundError: If the split directory is missing after the
            optional download step.
    """

    def __init__(self, root='./data', train=True, transform=None, download=True):
        import urllib.request
        import zipfile

        if transform is None:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            if train:
                transform = transforms.Compose([
                    transforms.RandomCrop(64, padding=8),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ])
            else:
                transform = transforms.Compose([
                    transforms.ToTensor(),
                    normalize,
                ])

        data_dir = os.path.join(root, 'tiny-imagenet-200')
        if download and not os.path.exists(data_dir):
            print("Downloading Tiny ImageNet (237 MB)...")
            url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
            zip_path = os.path.join(root, 'tiny-imagenet-200.zip')

            try:
                # urlretrieve cannot write into a directory that does not exist.
                os.makedirs(root, exist_ok=True)
                urllib.request.urlretrieve(url, zip_path)
                print("Extracting...")
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(root)
            except Exception as e:
                # Best effort: report and fall through to the explicit check below.
                print(f"Download failed: {e}")
                print("Please download manually from: http://cs231n.stanford.edu/tiny-imagenet-200.zip")
            finally:
                # Drop the archive on success and on failure, so a truncated
                # download is never left behind.
                try:
                    if os.path.exists(zip_path):
                        os.remove(zip_path)
                except OSError:
                    pass

        split = 'train' if train else 'val'
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(
                f"Tiny ImageNet split not found at {split_dir!r}. "
                "Please download manually from: http://cs231n.stanford.edu/tiny-imagenet-200.zip"
            )

        self.dataset = datasets.ImageFolder(split_dir, transform=transform)
        if not train:
            # A flat val/images layout makes ImageFolder label every sample 0;
            # restore the real labels from val_annotations.txt.
            self._relabel_val(self.dataset, data_dir)

    @staticmethod
    def _read_val_annotations(annotations_path):
        """Return ``{image file name: wnid}`` parsed from ``val_annotations.txt``.

        Only the first two whitespace-separated fields of each line are used;
        blank or malformed lines are skipped.
        """
        wnid_by_file = {}
        with open(annotations_path, encoding='utf-8') as fh:
            for line in fh:
                fields = line.split()
                if len(fields) >= 2:
                    wnid_by_file[fields[0]] = fields[1]
        return wnid_by_file

    @staticmethod
    def _class_names(data_dir):
        """Return the sorted class ids (wnids) that define the label indices.

        The sub-directories of ``train`` are preferred so that validation
        indices match what ``ImageFolder`` assigns to the training split;
        ``wnids.txt`` is the fallback when ``train`` is absent.
        """
        train_dir = os.path.join(data_dir, 'train')
        if os.path.isdir(train_dir):
            with os.scandir(train_dir) as entries:
                names = [entry.name for entry in entries if entry.is_dir()]
        else:
            with open(os.path.join(data_dir, 'wnids.txt'), encoding='utf-8') as fh:
                names = [line.strip() for line in fh if line.strip()]
        return sorted(names)

    @staticmethod
    def _relabel_val(dataset, data_dir):
        """Replace placeholder labels of a flat validation split with real ones.

        Does nothing unless ``dataset`` found the single pseudo-class
        ``images`` and ``val/val_annotations.txt`` exists, so a validation
        folder already organised into per-class directories is left untouched.
        """
        annotations_path = os.path.join(data_dir, 'val', 'val_annotations.txt')
        if list(dataset.classes) != ['images'] or not os.path.isfile(annotations_path):
            return

        wnid_by_file = TinyImageNetDataset._read_val_annotations(annotations_path)
        classes = TinyImageNetDataset._class_names(data_dir)
        class_to_idx = {name: index for index, name in enumerate(classes)}

        samples = [
            (path, class_to_idx[wnid_by_file[os.path.basename(path)]])
            for path, _ in dataset.samples
        ]
        dataset.classes = classes
        dataset.class_to_idx = class_to_idx
        dataset.samples = samples
        dataset.imgs = samples
        dataset.targets = [target for _, target in samples]

    def __len__(self):
        """Return the number of samples in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the item at ``idx`` from the wrapped ``ImageFolder``."""
        return self.dataset[idx]
| 209 | |
| 210 | class Food101Dataset: |
| 211 | def __init__(self, root='./data', split='train', transform=None, download=True): |