MCPcopy Index your code
hub / github.com/tensorlayer/SRGAN / get_train_data

Function get_train_data

train.py:35–71  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

33tl.files.exists_or_mkdir(checkpoint_dir)
34
def get_train_data():
    """Build the shuffled, batched training dataset of (LR, HR) patch pairs.

    Every ``.png`` file in ``config.TRAIN.hr_img_path`` is read into memory
    up front. For each image, a random 384x384x3 patch is cropped and scaled
    from [0, 255] to [-1, 1]. The patch is then randomly flipped
    left/right and downsampled to 96x96 with ``tf.image.resize`` to form the
    low-resolution input.

    Relies on the module-level names ``config``, ``shuffle_buffer_size`` and
    ``batch_size``.

    Returns:
        A ``tf.data.Dataset`` of ``(lr_patch, hr_patch)`` float32 batches with
        shapes ``(batch, 96, 96, 3)`` and ``(batch, 384, 384, 3)``. The final
        batch may be smaller than ``batch_size`` (no ``drop_remainder``).
    """
    hr_size = 384  # side length of the high-resolution training patch
    lr_size = 96   # side length of the low-resolution input (4x smaller)

    # Accept only names ending in ".png". The pattern is applied with
    # re.search, so the old '.*.png' (unescaped dot, unanchored) also matched
    # names like "xpng.txt" or "a.png.bak".
    train_hr_img_list = sorted(
        tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx=r'\.png$', printable=False))

    # Pre-load the entire training set; this requires enough memory to hold it.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)

    def generator_train():
        # Hand the pre-loaded images to tf.data one at a time.
        yield from train_hr_imgs

    def _map_fn_train(img):
        # Random crop; tf.image.random_crop fails if an image is smaller than
        # hr_size x hr_size or does not have exactly 3 channels.
        hr_patch = tf.image.random_crop(img, [hr_size, hr_size, 3])
        # Map pixel values from [0, 255] to [-1, 1].
        hr_patch = hr_patch / (255. / 2.)
        hr_patch = hr_patch - 1.
        hr_patch = tf.image.random_flip_left_right(hr_patch)
        # The LR input is the normalized, flipped HR patch downsampled with
        # tf.image.resize's default (bilinear) method.
        lr_patch = tf.image.resize(hr_patch, size=[lr_size, lr_size])
        return lr_patch, hr_patch

    # Images are converted to float32 as they leave the generator.
    train_ds = tf.data.Dataset.from_generator(generator_train, output_types=tf.float32)
    train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
    train_ds = train_ds.shuffle(shuffle_buffer_size)
    train_ds = train_ds.batch(batch_size)
    # Prefetch last so that whole batches are prepared while the model trains
    # on the current one. Placed before batch(), it would only buffer 2
    # individual patches.
    train_ds = train_ds.prefetch(buffer_size=2)
    return train_ds
72
73def train():
74 G = get_G((batch_size, 96, 96, 3))

Callers 1

train · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected