MCPcopy
hub / github.com/HobbitLong/SupContrast / set_loader

Function set_loader

main_ce.py:118–167  ·  view source on GitHub ↗
(opt)

Source from the content-addressed store, hash-verified

116
117
def set_loader(opt):
    """Build the training and validation DataLoaders for CIFAR-10/100.

    Args:
        opt: options object. Reads ``opt.dataset`` (``'cifar10'`` or
            ``'cifar100'``), ``opt.data_folder`` (dataset root; the training
            split is constructed with ``download=True``), ``opt.batch_size``
            and ``opt.num_workers``. An optional ``opt.val_batch_size``
            overrides the validation batch size (default 256).

    Returns:
        tuple: ``(train_loader, val_loader)``. The training loader shuffles
        and augments with RandomResizedCrop(32, scale=(0.2, 1.0)) and a
        random horizontal flip; the validation loader keeps dataset order and
        only converts to a tensor and normalizes.

    Raises:
        ValueError: if ``opt.dataset`` is not a supported dataset name.
    """
    # dataset name -> (torchvision dataset class name, per-channel mean,
    # per-channel std). The lookup happens before any transform or dataset
    # object is created, so an unsupported name fails immediately.
    dataset_table = {
        'cifar10': ('CIFAR10',
                    (0.4914, 0.4822, 0.4465),
                    (0.2023, 0.1994, 0.2010)),
        'cifar100': ('CIFAR100',
                     (0.5071, 0.4867, 0.4408),
                     (0.2675, 0.2565, 0.2761)),
    }
    try:
        class_name, mean, std = dataset_table[opt.dataset]
    except KeyError:
        raise ValueError(
            'dataset not supported: {}'.format(opt.dataset)) from None

    dataset_cls = getattr(datasets, class_name)
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = dataset_cls(root=opt.data_folder,
                                transform=train_transform,
                                download=True)
    # No download flag here; this relies on the training-split download
    # above having already placed the data under opt.data_folder.
    val_dataset = dataset_cls(root=opt.data_folder,
                              train=False,
                              transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.num_workers, pin_memory=True)
    # Cap validation workers at 8 (the former fixed value) but never exceed
    # the user's setting, so e.g. num_workers=0 for debugging applies here too.
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=getattr(opt, 'val_batch_size', 256),
        shuffle=False, num_workers=min(8, opt.num_workers), pin_memory=True)

    return train_loader, val_loader
168
169
170def set_model(opt):

Callers 2

mainFunction · 0.90
mainFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected