MCPcopy
hub / github.com/jindongwang/transferlearning / ImageTextData

Class ImageTextData

code/clip/data/data_loader.py:10–93  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

8
9
10class ImageTextData(object):
11
def __init__(self, dataset, root, preprocess, prompt='a picture of a'):
    """Load an image-folder dataset and tokenize one text label per class.

    Args:
        dataset: Either an integer index into ``_DATA_FOLDER`` or a folder
            path given relative to ``root``.
        root: Directory that ``dataset`` is joined onto.
        preprocess: Callable applied to each opened image in
            ``__getitem__``; ``None`` makes ``__getitem__`` return the
            image path instead.
        prompt: Text prepended (followed by one space) to every class
            label before tokenizing. A falsy value keeps labels as-is.
    """
    if isinstance(dataset, int):
        dataset = self._DATA_FOLDER[dataset]
    dataset = os.path.join(root, dataset)
    # NOTE(review): this comparison runs after ``root`` has been joined on,
    # so it only matches when the joined path is literally 'imagenet-r'
    # (e.g. an empty root). Confirm whether that is intended.
    if dataset == 'imagenet-r':
        # NOTE(review): ``_TRANSFORM`` is not defined in the visible part
        # of the class; presumably declared further down, verify.
        data = datasets.ImageFolder(
            'imagenet-r', transform=self._TRANSFORM)
        # Close the label file deterministically instead of leaking the
        # handle until garbage collection.
        with open('imagenetr_labels.txt') as label_file:
            labels = label_file.read().splitlines()
        # Each line is comma-separated; the second field is the label.
        labels = [x.split(',')[1].strip() for x in labels]
    else:
        data = datasets.ImageFolder(dataset, transform=self._TRANSFORM)
        labels = data.classes
    self.data = data
    self.labels = labels
    if prompt:
        self.labels = [prompt + ' ' + x for x in self.labels]

    self.preprocess = preprocess
    # One tokenized entry per class, indexed by class label in __getitem__.
    self.text = clip.tokenize(self.labels)
31
32 def __getitem__(self, index):
33 image, label = self.data.imgs[index]
34 if self.preprocess is not None:
35 image = self.preprocess(Image.open(image))
36 text_enc = self.text[label]
37 return image, text_enc, label
38
39 def __len__(self):
40 return len(self.data)
41
42 @staticmethod
43 def get_data_name_by_index(index):
44 name = ImageTextData._DATA_FOLDER[index]
45 name = name.replace('/', '_')
46 return name
47
48 _DATA_FOLDER = [
49 'dataset/OfficeHome/Art',
50 'dataset/OfficeHome/Clipart',
51 'dataset/OfficeHome/Product',
52 'dataset/OfficeHome/RealWorld',
53
54 'dataset/office31/amazon', # 4
55 'dataset/office31/webcam',
56 'dataset/office31/dslr',
57
58 'dataset/VLCS/Caltech101', # 7
59 'dataset/VLCS/LabelMe',
60 'dataset/VLCS/SUN09',
61 'dataset/VLCS/VOC2007',
62
63 'dataset/PACS/kfold/art_painting', # 11
64 'dataset/PACS/kfold/cartoon',
65 'dataset/PACS/kfold/photo',
66 'dataset/PACS/kfold/sketch',
67

Callers 1

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected