| 64 | return image, label |
| 65 | |
class SyntheticDataset(Dataset):
    """Map-style dataset of randomly generated images with index-derived labels.

    No data is stored: every lookup draws a fresh ``torch.randn`` tensor of
    shape ``(channels, image_size, image_size)``, so fetching the same index
    twice yields different pixel values. The label is fully determined by the
    index (``idx % num_classes``).

    Args:
        num_samples: Number of samples the dataset reports via ``len()``.
        num_classes: Number of distinct labels; labels cycle through
            ``0 .. num_classes - 1``.
        image_size: Height and width of each generated square image.
        channels: Number of channels of each generated image.
    """

    def __init__(
        self,
        num_samples: int = 10000,
        num_classes: int = 10,
        image_size: int = 224,
        channels: int = 3
    ):
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.image_size = image_size
        self.channels = channels

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return a ``(image, label)`` pair for position ``idx``.

        Args:
            idx: Sample position. Negative values count from the end, as
                with a regular sequence.

        Returns:
            A random image tensor of shape
            ``(channels, image_size, image_size)`` and the integer label
            ``idx % num_classes``.

        Raises:
            IndexError: If ``idx`` is outside ``[-num_samples, num_samples)``.
        """
        # Without this check plain iteration (``for x in dataset``) never
        # stops: Python's fallback sequence protocol keeps calling
        # ``__getitem__`` with 0, 1, 2, ... until IndexError is raised.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size "
                f"{self.num_samples}"
            )
        if idx < 0:
            # Normalize so dataset[-1] labels the same as dataset[len - 1].
            idx += self.num_samples

        image = torch.randn(self.channels, self.image_size, self.image_size)
        label = idx % self.num_classes
        return image, label
| 86 | |
| 87 | class MemoryDataset(Dataset): |
| 88 | def __init__(self, data: torch.Tensor, labels: torch.Tensor): |