| 33 | tl.files.exists_or_mkdir(checkpoint_dir) |
| 34 | |
def get_train_data():
    """Build the training ``tf.data.Dataset`` of (LR, HR) patch pairs.

    Every file under ``config.TRAIN.hr_img_path`` matching the regex
    ``'.*.png'`` is loaded into memory up front. Each pass over the returned
    dataset visits every loaded image once (there is no ``repeat``). For each
    image it:

    1. takes a random 384x384x3 crop (the HR patch);
    2. rescales it with ``x / 127.5 - 1`` (maps [0, 255] to [-1, 1],
       assuming 8-bit pixel values -- TODO confirm for all inputs);
    3. randomly flips it left/right;
    4. resizes it to 96x96 with ``tf.image.resize``'s default method to
       produce the LR patch.

    Patches are shuffled with ``shuffle_buffer_size`` and batched with
    ``batch_size``. Both are module-level settings defined outside this
    function.

    Returns:
        tf.data.Dataset: yields ``(lr_patch, hr_patch)`` float32 batches of
        shape ``(b, 96, 96, 3)`` and ``(b, 384, 384, 3)``. ``b`` equals
        ``batch_size`` except possibly for the last batch of a pass, because
        ``drop_remainder`` is left at its default (False).
    """
    # List the HR training images in a deterministic (sorted) order.
    train_hr_img_list = sorted(tl.files.load_file_list(
        path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))

    # Pre-load the entire train set; the machine must have enough memory
    # to hold every image at once.
    train_hr_imgs = tl.vis.read_images(
        train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)

    def generator_train():
        # Feed the pre-loaded images one at a time to Dataset.from_generator.
        for img in train_hr_imgs:
            yield img

    def _map_fn_train(img):
        # Random HR crop, then scale pixel values into [-1, 1].
        hr_patch = tf.image.random_crop(img, [384, 384, 3])
        hr_patch = hr_patch / (255. / 2.)
        hr_patch = hr_patch - 1.
        # Flip before deriving the LR patch so the LR/HR pair stays aligned.
        hr_patch = tf.image.random_flip_left_right(hr_patch)
        lr_patch = tf.image.resize(hr_patch, size=[96, 96])
        return lr_patch, hr_patch

    train_ds = tf.data.Dataset.from_generator(generator_train, output_types=tf.float32)
    train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
    # Shuffle the (small) cropped patches rather than the full-size images.
    train_ds = train_ds.shuffle(shuffle_buffer_size)
    train_ds = train_ds.batch(batch_size)
    # Prefetch as the final stage so complete batches are prepared while the
    # model trains on the current one.
    train_ds = train_ds.prefetch(buffer_size=2)
    return train_ds
| 72 | |
| 73 | def train(): |
| 74 | G = get_G((batch_size, 96, 96, 3)) |