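Load images from CycleGAN's database, see `this link <https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/>`__. Parameters: ``filename`` (str) — the dataset you want, see `this link <https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/>`__; ``path`` (str) — the directory the data is downloaded to (a ``cyclegan`` sub-directory is appended, so with the default ``path='data'`` the data lands in `data/cyclegan`).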
(filename='summer2winter_yosemite', path='data')
| 1218 | |
| 1219 | |
def load_cyclegan_dataset(filename='summer2winter_yosemite', path='data'):
    """Load images from CycleGAN's database, see `this link <https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/>`__.

    If ``<path>/cyclegan/<filename>`` does not exist yet, ``<filename>.zip`` is
    downloaded into ``<path>/cyclegan``, extracted, and the zip file is then
    deleted. Later calls reuse the extracted folder.

    Parameters
    ------------
    filename : str
        The dataset you want, see `this link <https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/>`__.
    path : str
        The directory that the data is downloaded to. A ``cyclegan``
        sub-directory is appended, so with the default ``path='data'`` the
        data is stored in `data/cyclegan`.

    Returns
    -------
    im_train_A, im_train_B, im_test_A, im_test_B
        The images read from the ``trainA``, ``trainB``, ``testA`` and
        ``testB`` sub-folders (only files matching ``\\.jpg``), in that order.
        Each is the collection returned by ``visualize.read_images`` for that
        folder, with every 2-D (grayscale) image replicated into 3 channels,
        i.e. ``[h, w]`` --> ``[h, w, 3]``.

    Examples
    ---------
    >>> im_train_A, im_train_B, im_test_A, im_test_B = load_cyclegan_dataset(filename='summer2winter_yosemite')

    """
    path = os.path.join(path, 'cyclegan')
    url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'

    # Download only when the extracted dataset folder is missing.
    if not folder_exists(os.path.join(path, filename)):
        logging.info("[*] %s is nonexistent in %s", filename, path)
        maybe_download_and_extract(filename + '.zip', path, url, extract=True)
        # The archive is no longer needed once it has been extracted.
        del_file(os.path.join(path, filename + '.zip'))

    def load_image_from_folder(folder):
        """Read every file in ``folder`` whose name matches the regex ``\\.jpg``."""
        path_imgs = load_file_list(path=folder, regx=r'\.jpg', printable=False)
        return visualize.read_images(path_imgs, path=folder, n_threads=10, printable=False)

    def if_2d_to_3d(images):  # [h, w] --> [h, w, 3]
        """Replicate each 2-D (grayscale) image into 3 channels, in place."""
        for i, img in enumerate(images):
            if len(img.shape) == 2:
                images[i] = np.tile(img[:, :, np.newaxis], (1, 1, 3))
        return images

    im_train_A, im_train_B, im_test_A, im_test_B = [
        if_2d_to_3d(load_image_from_folder(os.path.join(path, filename, split)))
        for split in ("trainA", "trainB", "testA", "testB")
    ]

    return im_train_A, im_train_B, im_test_A, im_test_B
| 1265 | |
| 1266 | |
| 1267 | def download_file_from_google_drive(ID, destination): |
nothing calls this directly
no test coverage detected
searching dependent graphs…