MCPcopy
hub / github.com/whai362/PVT / ClassificationDataset

Class ClassificationDataset

classification/mcloader/classification.py:6–31  ·  view source on GitHub ↗

Dataset for classification.

Source fetched from the content-addressed store (hash-verified).

4
5
class ClassificationDataset(Dataset):
    """Dataset for classification backed by an ``ImageNet`` data source.

    Args:
        split (str): ``'train'`` selects the training subset; any other
            value selects the ``'val'`` subset.
        pipeline (callable, optional): Transform applied to each image in
            ``__getitem__``. When ``None``, images are returned unchanged.
        data_root (str): Directory holding the ``train``/``val`` image
            folders and the ``meta/<subset>.txt`` list files.
            Defaults to ``'data/imagenet'``.
        memcached (bool): Forwarded unchanged to ``ImageNet``.
        mclient_path (str): Forwarded unchanged to ``ImageNet``.
    """

    def __init__(self, split='train', pipeline=None, *,
                 data_root='data/imagenet', memcached=True,
                 mclient_path='/mnt/lustre/share/memcached_client'):
        import os.path

        # Anything other than 'train' falls back to the validation subset;
        # this mirrors the original if/else and keeps existing callers working.
        subset = 'train' if split == 'train' else 'val'
        self.data_source = ImageNet(
            root=os.path.join(data_root, subset),
            list_file=os.path.join(data_root, 'meta', subset + '.txt'),
            memcached=memcached,
            mclient_path=mclient_path)
        self.pipeline = pipeline

    def __len__(self):
        """Return the number of samples reported by the data source."""
        return self.data_source.get_length()

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``.

        ``img`` is passed through ``self.pipeline`` when one is set.
        """
        img, target = self.data_source.get_sample(idx)
        if self.pipeline is not None:
            img = self.pipeline(img)

        return img, target

Callers 1

build_dataset (Function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected