| 26 | |
| 27 | |
| 28 | class CondGANTrainer(object): |
def __init__(self,
             model,
             dataset=None,
             exp_name="model",
             ckt_logs_dir="ckt_logs",
             ):
    """Record the model, the data source and the training settings.

    :param model: network wrapper driven by this trainer.
    :type model: RegularizedGAN
    :param dataset: data source; its ``image_shape`` and
        ``embedding_shape`` are read when the placeholders are built.
    :param exp_name: experiment name, stored unchanged.
    :param ckt_logs_dir: single directory stored as both the log
        directory and the checkpoint directory.
    """
    # Training hyper-parameters come from the global config object.
    train_cfg = cfg.TRAIN

    self.model = model
    self.dataset = dataset
    self.exp_name = exp_name
    # Logs and checkpoints share the same directory.
    self.log_dir = self.checkpoint_dir = ckt_logs_dir

    self.batch_size = train_cfg.BATCH_SIZE
    self.max_epoch = train_cfg.MAX_EPOCH
    self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL
    self.model_path = train_cfg.PRETRAINED_MODEL

    # Starts empty; nothing is appended to it here.
    self.log_vars = []
| 50 | |
def build_placeholder(self):
    """Create the graph inputs; helper function for init_opt.

    Defines, in this order, float32 placeholders for the real image
    batch, the wrong image batch and the conditional embedding batch,
    followed by two scalar float32 placeholders holding the generator
    and discriminator learning rates. Each one is stored on ``self``.
    """
    batch_dim = [self.batch_size]

    # Real and wrong image batches share the same shape:
    # batch size prepended to the dataset's image shape.
    image_dims = batch_dim + self.dataset.image_shape
    self.images = tf.placeholder(
        dtype=tf.float32, shape=image_dims, name='real_images')
    self.wrong_images = tf.placeholder(
        dtype=tf.float32, shape=image_dims, name='wrong_images')

    embedding_dims = batch_dim + self.dataset.embedding_shape
    self.embeddings = tf.placeholder(
        dtype=tf.float32, shape=embedding_dims,
        name='conditional_embeddings')

    # Learning rates are scalars (empty shape) fed in at run time.
    scalar_dims = []
    self.generator_lr = tf.placeholder(
        dtype=tf.float32, shape=scalar_dims,
        name='generator_learning_rate')
    self.discriminator_lr = tf.placeholder(
        dtype=tf.float32, shape=scalar_dims,
        name='discriminator_learning_rate')
| 73 | |
| 74 | def sample_encoded_context(self, embeddings): |
| 75 | '''Helper function for init_opt''' |
| 76 | c_mean_logsigma = self.model.generate_condition(embeddings) |
| 77 | mean = c_mean_logsigma[0] |
| 78 | if cfg.TRAIN.COND_AUGMENTATION: |
| 79 | # epsilon = tf.random_normal(tf.shape(mean)) |
| 80 | epsilon = tf.truncated_normal(tf.shape(mean)) |
| 81 | stddev = tf.exp(c_mean_logsigma[1]) |
| 82 | c = mean + stddev * epsilon |
| 83 | |
| 84 | kl_loss = KL_loss(c_mean_logsigma[0], c_mean_logsigma[1]) |
| 85 | else: |