| 8 | |
| 9 | |
| 10 | class ImageTextData(object): |
| 11 | |
def __init__(self, dataset, root, preprocess, prompt='a picture of a'):
    """Wrap an ``ImageFolder`` and pre-tokenize one text prompt per class.

    Args:
        dataset: An ``int`` index into ``_DATA_FOLDER``, or a folder name
            resolved relative to ``root``. The name ``'imagenet-r'`` takes
            its class names from ``imagenetr_labels.txt`` instead of the
            folder names.
        root: Directory that the dataset folder is joined onto.
        preprocess: Callable applied to each opened image in
            ``__getitem__``, or ``None`` to return the raw file path.
        prompt: Prefix joined with a space to every class name before
            tokenization; a falsy value leaves the names unchanged.
    """
    if isinstance(dataset, int):
        dataset = self._DATA_FOLDER[dataset]
    # Decide on the special case from the dataset *name*: the joined path
    # equals 'imagenet-r' only when ``root`` is empty, so testing it after
    # os.path.join made this branch unreachable for any real root.
    is_imagenet_r = dataset == 'imagenet-r'
    dataset_path = os.path.join(root, dataset)
    data = datasets.ImageFolder(dataset_path, transform=self._TRANSFORM)
    if is_imagenet_r:
        # NOTE(review): the labels file is resolved against the current
        # working directory, not ``root`` -- confirm this is intended.
        with open('imagenetr_labels.txt', encoding='utf-8') as f:
            lines = f.read().splitlines()
        # The second comma-separated field of each non-blank line is used
        # as the readable class name; presumably the lines follow the same
        # order as the sorted class folders -- verify against the file.
        labels = [x.split(',')[1].strip() for x in lines if x.strip()]
    else:
        labels = data.classes
    self.data = data
    self.labels = labels
    if prompt:
        self.labels = [prompt + ' ' + x for x in self.labels]

    self.preprocess = preprocess
    # One tokenized prompt per class; __getitem__ indexes it by label.
    self.text = clip.tokenize(self.labels)
| 31 | |
| 32 | def __getitem__(self, index): |
| 33 | image, label = self.data.imgs[index] |
| 34 | if self.preprocess is not None: |
| 35 | image = self.preprocess(Image.open(image)) |
| 36 | text_enc = self.text[label] |
| 37 | return image, text_enc, label |
| 38 | |
| 39 | def __len__(self): |
| 40 | return len(self.data) |
| 41 | |
| 42 | @staticmethod |
| 43 | def get_data_name_by_index(index): |
| 44 | name = ImageTextData._DATA_FOLDER[index] |
| 45 | name = name.replace('/', '_') |
| 46 | return name |
| 47 | |
| 48 | _DATA_FOLDER = [ |
| 49 | 'dataset/OfficeHome/Art', |
| 50 | 'dataset/OfficeHome/Clipart', |
| 51 | 'dataset/OfficeHome/Product', |
| 52 | 'dataset/OfficeHome/RealWorld', |
| 53 | |
| 54 | 'dataset/office31/amazon', # 4 |
| 55 | 'dataset/office31/webcam', |
| 56 | 'dataset/office31/dslr', |
| 57 | |
| 58 | 'dataset/VLCS/Caltech101', # 7 |
| 59 | 'dataset/VLCS/LabelMe', |
| 60 | 'dataset/VLCS/SUN09', |
| 61 | 'dataset/VLCS/VOC2007', |
| 62 | |
| 63 | 'dataset/PACS/kfold/art_painting', # 11 |
| 64 | 'dataset/PACS/kfold/cartoon', |
| 65 | 'dataset/PACS/kfold/photo', |
| 66 | 'dataset/PACS/kfold/sketch', |
| 67 | |