Args: name (str): 'train', 'test', 'val' data_dir (str): a directory containing the original 'BSR' directory.
(self, name, data_dir=None, shuffle=True)
| 29 | """ |
| 30 | |
def __init__(self, name, data_dir=None, shuffle=True):
    """
    Locate the dataset on disk (downloading and unpacking it if the 'BSR'
    directory is missing) and load one split.

    Args:
        name (str): 'train', 'test', 'val'
        data_dir (str): a directory containing the original 'BSR' directory.
            When None, the path returned by
            ``get_dataset_path('bsds500_data')`` is used.
        shuffle (bool): stored on the instance as ``self.shuffle``.

    Raises:
        AssertionError: if ``name`` is not one of the three splits, or if
            the 'BSR/BSDS500/data' directory does not exist after the
            download/extract step.
    """
    # Validate the split name first so a typo fails before any download starts.
    assert name in ['train', 'test', 'val'], name

    # check and download data
    if data_dir is None:
        data_dir = get_dataset_path('bsds500_data')
    if not os.path.isdir(os.path.join(data_dir, 'BSR')):
        download(DATA_URL, data_dir, expect_size=DATA_SIZE)
        filename = DATA_URL.split('/')[-1]
        filepath = os.path.join(data_dir, filename)
        import tarfile
        # The context manager closes the archive handle even if extraction fails.
        with tarfile.open(filepath, 'r:gz') as archive:
            archive.extractall(data_dir)
    self.data_root = os.path.join(data_dir, 'BSR', 'BSDS500', 'data')
    assert os.path.isdir(self.data_root), self.data_root

    self.shuffle = shuffle
    self._load(name)
| 52 | |
| 53 | def _load(self, name): |
| 54 | image_glob = os.path.join(self.data_root, 'images', name, '*.jpg') |
nothing calls this directly
no test coverage detected