MCPcopy
hub / github.com/geekcomputers/Python / STL10Dataset

Class STL10Dataset

ML/src/python/neuralforge/data/datasets.py:92–115  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

90 return self.dataset[idx]
91
class STL10Dataset:
    """Wrapper around ``datasets.STL10`` that supplies default transforms.

    When no ``transform`` is given, splits that contain the labeled training
    images (``'train'`` and ``'train+unlabeled'``) get random-crop and
    horizontal-flip augmentation before tensor conversion and normalization.
    Every other split (e.g. ``'test'``, ``'unlabeled'``) gets only tensor
    conversion and normalization.

    Args:
        root: Directory passed to ``datasets.STL10`` as its storage root.
        split: Split name, forwarded unchanged to ``datasets.STL10``.
        transform: Optional per-image transform; when given, it replaces the
            defaults entirely.
        download: Forwarded unchanged to ``datasets.STL10``.

    Attributes:
        dataset: The wrapped ``datasets.STL10`` instance.
        classes: Names of the 10 labeled classes, in label-index order.
    """

    # Per-channel normalization statistics shared by both default pipelines.
    _MEAN = (0.4467, 0.4398, 0.4066)
    _STD = (0.2603, 0.2566, 0.2713)

    # Splits that include the labeled training images and therefore receive
    # training-time augmentation by default. Previously only 'train' did, so
    # 'train+unlabeled' was silently trained without augmentation.
    _AUGMENTED_SPLITS = ('train', 'train+unlabeled')

    _CLASSES = ('airplane', 'bird', 'car', 'cat', 'deer',
                'dog', 'horse', 'monkey', 'ship', 'truck')

    def __init__(self, root='./data', split='train', transform=None, download=True):
        if transform is None:
            transform = self._default_transform(split)

        self.dataset = datasets.STL10(root=root, split=split, transform=transform, download=download)
        # Fresh list per instance so mutating one instance's `classes`
        # cannot leak into another.
        self.classes = list(self._CLASSES)

    @classmethod
    def _uses_augmentation(cls, split):
        """Return True if *split* should get the augmenting default transform."""
        return split in cls._AUGMENTED_SPLITS

    @classmethod
    def _default_transform(cls, split):
        """Build the default transform pipeline for *split*."""
        steps = []
        if cls._uses_augmentation(split):
            # Pad by 12 px on each side, then crop back to 96x96 (random
            # translation), plus a random horizontal flip.
            steps += [
                transforms.RandomCrop(96, padding=12),
                transforms.RandomHorizontalFlip(),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(cls._MEAN, cls._STD),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample *idx* from the wrapped dataset unchanged."""
        return self.dataset[idx]
116
117def get_dataset(name='cifar10', root='./data', train=True, download=True):
118 name = name.lower()

Callers 1

get_dataset (function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected