| 27 | |
| 28 | |
| 29 | class CondGANTrainer(object): |
def __init__(self,
             model,
             dataset=None,
             exp_name="model",
             ckt_logs_dir="ckt_logs",
             ):
    """
    Store the model and dataset, copy training settings from the
    global ``cfg`` and derive the high-/low-resolution image shapes.

    :type model: RegularizedGAN
    :param dataset: object providing ``image_shape`` (indexable, at
        least three entries) and ``hr_lr_ratio``. The signature keeps
        the ``None`` default, but both attributes are read here, so
        ``None`` is rejected up front.
    :param exp_name: experiment name, stored unchanged.
    :param ckt_logs_dir: directory stored as both ``log_dir`` and
        ``checkpoint_dir``.
    :raises ValueError: if ``dataset`` is None.
    """
    if dataset is None:
        # Without this guard the code below fails on
        # ``self.dataset.image_shape`` with an opaque AttributeError
        # about NoneType; fail early with an actionable message.
        raise ValueError(
            "CondGANTrainer needs a dataset: image_shape and "
            "hr_lr_ratio are read from it")

    self.model = model
    self.dataset = dataset
    self.exp_name = exp_name
    # Logs and checkpoints share the same directory.
    self.log_dir = ckt_logs_dir
    self.checkpoint_dir = ckt_logs_dir

    # Training settings come from the module-level config object.
    self.batch_size = cfg.TRAIN.BATCH_SIZE
    self.max_epoch = cfg.TRAIN.MAX_EPOCH
    self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
    self.model_path = cfg.TRAIN.PRETRAINED_MODEL

    # Starts empty.
    self.log_vars = []

    self.hr_image_shape = self.dataset.image_shape
    ratio = self.dataset.hr_lr_ratio
    # The first two entries are divided by the ratio (truncated to
    # int); the third is copied unchanged. Presumably the layout is
    # (height, width, channels) -- confirm against the dataset class.
    self.lr_image_shape = [int(self.hr_image_shape[0] / ratio),
                           int(self.hr_image_shape[1] / ratio),
                           self.hr_image_shape[2]]
    print('hr_image_shape', self.hr_image_shape)
    print('lr_image_shape', self.lr_image_shape)
| 59 | |
def build_placeholder(self):
    """Helper function for init_opt.

    Creates float32 placeholders for the real and the "wrong"
    high-resolution image batches, the conditional embeddings and the
    two scalar learning rates, then bilinearly resizes both image
    batches to the first two entries of ``self.lr_image_shape``.
    """
    # Both image placeholders share the same batch shape.
    hr_batch_shape = [self.batch_size] + self.hr_image_shape

    self.hr_images = tf.placeholder(dtype=tf.float32,
                                    shape=hr_batch_shape,
                                    name='real_hr_images')
    self.hr_wrong_images = tf.placeholder(dtype=tf.float32,
                                          shape=hr_batch_shape,
                                          name='wrong_hr_images')

    embedding_batch_shape = [self.batch_size] + self.dataset.embedding_shape
    self.embeddings = tf.placeholder(dtype=tf.float32,
                                     shape=embedding_batch_shape,
                                     name='conditional_embeddings')

    # Learning rates are fed as scalars (empty shape).
    self.generator_lr = tf.placeholder(dtype=tf.float32,
                                       shape=[],
                                       name='generator_learning_rate')
    self.discriminator_lr = tf.placeholder(
        dtype=tf.float32,
        shape=[],
        name='discriminator_learning_rate')

    # Resized copies of the high-resolution inputs; the real batch is
    # resized first, then the wrong batch, as in the original order.
    lr_size = self.lr_image_shape[:2]
    self.images = tf.image.resize_bilinear(self.hr_images, lr_size)
    self.wrong_images = tf.image.resize_bilinear(self.hr_wrong_images,
                                                 lr_size)