Args: name (str): one of 'train', 'test', or 'extra'. data_dir (str): a directory containing the original {train,test,extra}_32x32.mat files. shuffle (bool): whether to shuffle the dataset.
(self, name, data_dir=None, shuffle=True)
# Class-wide cache: (name, data_dir as passed by the caller) -> (X, Y).
# Every instance created with the same key shares the same array objects.
_Cache = {}

def __init__(self, name, data_dir=None, shuffle=True):
    """
    Load one split of the SVHN digit .mat files, downloading it if missing.

    Args:
        name (str): 'train', 'test', or 'extra'.
        data_dir (str): a directory containing the original {train,test,extra}_32x32.mat.
            If None, ``get_dataset_path('svhn_data')`` is used. A missing .mat
            file is downloaded into this directory.
        shuffle (bool): shuffle the dataset.
    """
    self.shuffle = shuffle

    # Key on the directory as well as the split name: keying on the name
    # alone would return arrays loaded earlier from a different data_dir.
    cache_key = (name, data_dir)
    if cache_key in SVHNDigit._Cache:
        # NOTE: these arrays are shared with every other instance holding the
        # same key; mutating them in place affects all of those instances.
        self.X, self.Y = SVHNDigit._Cache[cache_key]
        return

    # Validate before resolving/creating paths or touching the network.
    assert name in ['train', 'test', 'extra'], name
    if data_dir is None:
        data_dir = get_dataset_path('svhn_data')
    filename = os.path.join(data_dir, name + '_32x32.mat')
    if not os.path.isfile(filename):
        url = SVHN_URL + os.path.basename(filename)
        logger.info("File {} not found!".format(filename))
        logger.info("Downloading from {} ...".format(url))
        download(url, os.path.dirname(filename))
    logger.info("Loading {} ...".format(filename))
    data = scipy.io.loadmat(filename)
    # Move the last axis of the stored 'X' array to the front so that axis 0
    # indexes samples (__len__ reports X.shape[0]).
    self.X = data['X'].transpose(3, 0, 1, 2)
    self.Y = data['y'].reshape((-1))
    # Remap label 10 to 0 -- presumably SVHN's encoding of the digit 0; confirm
    # against the dataset's format description.
    self.Y[self.Y == 10] = 0
    SVHNDigit._Cache[cache_key] = (self.X, self.Y)
| 51 | |
def __len__(self):
    """Return how many samples this dataset holds: the extent of axis 0 of ``X``."""
    sample_axes = self.X.shape
    return sample_axes[0]
nothing calls this directly
no test coverage detected