MCPcopy Index your code
hub / github.com/geekcomputers/Python / CIFAR10Dataset

Class CIFAR10Dataset

ML/src/python/neuralforge/data/datasets.py:7–30  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

5from typing import Optional, Callable
6
class CIFAR10Dataset:
    """Thin wrapper around ``datasets.CIFAR10`` that supplies default transforms.

    When no ``transform`` is given, a default pipeline is chosen by ``train``:

    * ``train=True``: ``RandomCrop(32, padding=4)`` -> ``RandomHorizontalFlip()``
      -> ``ToTensor()`` -> ``Normalize(_MEAN, _STD)``
    * ``train=False``: ``ToTensor()`` -> ``Normalize(_MEAN, _STD)``

    Args:
        root: Directory forwarded to ``datasets.CIFAR10``.
        train: Forwarded to ``datasets.CIFAR10``; also selects which default
            transform pipeline is built.
        transform: Callable forwarded to ``datasets.CIFAR10``. When given, it
            replaces the default pipeline entirely.
        download: Forwarded to ``datasets.CIFAR10``.

    Attributes:
        dataset: The wrapped ``datasets.CIFAR10`` instance.
        classes: Fresh per-instance list of the ten class-name strings.
    """

    # Positional (mean, std) arguments handed to transforms.Normalize.
    # NOTE(review): these std values are the widely copied ones; the commonly
    # cited per-pixel CIFAR-10 std is roughly (0.247, 0.243, 0.261). They are
    # kept unchanged so existing checkpoints see identical inputs. Confirm before
    # changing.
    _MEAN = (0.4914, 0.4822, 0.4465)
    _STD = (0.2023, 0.1994, 0.2010)

    _CLASS_NAMES = (
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    )

    def __init__(self, root='./data', train=True, transform=None, download=True):
        if transform is None:
            # Augmentation is only applied to the training split and is placed
            # ahead of tensor conversion and normalization.
            augment = (
                [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
                if train
                else []
            )
            transform = transforms.Compose(
                augment + [transforms.ToTensor(), transforms.Normalize(self._MEAN, self._STD)]
            )

        self.dataset = datasets.CIFAR10(
            root=root, train=train, transform=transform, download=download
        )
        # Copy into a new list so callers mutating one instance's classes
        # cannot affect another instance.
        self.classes = list(self._CLASS_NAMES)

    def __len__(self):
        """Return the number of samples reported by the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return whatever the wrapped dataset yields for ``idx``, unchanged."""
        return self.dataset[idx]
31
32class CIFAR100Dataset:
33 def __init__(self, root='./data', train=True, transform=None, download=True):

Callers 1

get_dataset · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected