| 12 | |
| 13 | |
| 14 | class CondGAN(object): |
def __init__(self, lr_imsize, hr_lr_ratio):
    """Read GAN hyper-parameters from ``cfg`` and build discriminator templates.

    Args:
        lr_imsize: Side length of the low-resolution images. Stored as
            ``self.s``; ``self.s2``, ``self.s4``, ``self.s8`` and ``self.s16``
            hold it divided by 2, 4, 8 and 16, truncated with ``int``.
        hr_lr_ratio: Stored unchanged as ``self.hr_lr_ratio`` (not used in
            this constructor; presumably the high-res / low-res size ratio —
            confirm against callers).

    Raises:
        NotImplementedError: if ``cfg.GAN.NETWORK_TYPE`` is not ``"default"``.
    """
    self.batch_size = cfg.TRAIN.BATCH_SIZE
    self.network_type = cfg.GAN.NETWORK_TYPE
    self.hr_lr_ratio = hr_lr_ratio
    self.gf_dim = cfg.GAN.GF_DIM
    self.df_dim = cfg.GAN.DF_DIM
    self.ef_dim = cfg.GAN.EMBEDDING_DIM

    self.s = lr_imsize
    print('lr_imsize: ', lr_imsize)
    # Spatial sizes at 1/2, 1/4, 1/8 and 1/16 of the LR side; e.g. s16 is the
    # side length the stage-I generator reshapes its first dense output to.
    self.s2, self.s4, self.s8, self.s16 = \
        int(self.s / 2), int(self.s / 4), int(self.s / 8), int(self.s / 16)

    # Branch on the value stored above rather than re-reading cfg, so the
    # recorded network type and the graph that gets built cannot disagree.
    if self.network_type != "default":
        raise NotImplementedError(
            "Unsupported cfg.GAN.NETWORK_TYPE: %r" % (self.network_type,))

    # Low-resolution discriminator templates, built inside the "d_net" scope.
    with tf.variable_scope("d_net"):
        self.d_context_template = self.context_embedding()
        self.d_image_template = self.d_encode_image()
        self.d_discriminator_template = self.discriminator()

    # High-resolution discriminator templates, built inside a separate
    # "hr_d_net" scope (same context/discriminator builders, HR image encoder).
    with tf.variable_scope("hr_d_net"):
        self.hr_d_context_template = self.context_embedding()
        self.hr_d_image_template = self.hr_d_encode_image()
        self.hr_discriminator_template = self.discriminator()
| 39 | |
def generate_condition(self, c_var):
    """Project a text embedding to conditioning-augmentation statistics.

    This structure is used by both g and hr_g, but each builds it
    separately, so the two do not share parameters.

    Args:
        c_var: Text-embedding tensor. It is flattened before the
            fully-connected layer (assumes a leading batch axis — TODO confirm).

    Returns:
        A two-element list ``[mean, log_sigma]``. These are the first and the
        remaining ``self.ef_dim`` columns of a ``2 * self.ef_dim``-wide
        ``custom_fully_connected`` projection passed through
        ``leaky_rectify`` with leakiness 0.2.
    """
    half_width = self.ef_dim
    flat_embedding = pt.wrap(c_var).flatten()
    projection = flat_embedding.custom_fully_connected(half_width * 2)
    stats = projection.apply(leaky_rectify, leakiness=0.2)
    # Split the projection column-wise into the two statistics.
    mean, log_sigma = stats[:, :half_width], stats[:, half_width:]
    return [mean, log_sigma]
| 52 | |
| 53 | # stage I generator (g) |
| 54 | def generator(self, z_var): |
| 55 | node1_0 =\ |
| 56 | (pt.wrap(z_var). |
| 57 | flatten(). |
| 58 | custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8). |
| 59 | fc_batch_norm(). |
| 60 | reshape([-1, self.s16, self.s16, self.gf_dim * 8])) |
| 61 | node1_1 = \ |
| 62 | (node1_0. |
| 63 | custom_conv2d(self.gf_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1). |
| 64 | conv_batch_norm(). |
| 65 | apply(tf.nn.relu). |
| 66 | custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1). |
| 67 | conv_batch_norm(). |
| 68 | apply(tf.nn.relu). |
| 69 | custom_conv2d(self.gf_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1). |
| 70 | conv_batch_norm()) |
| 71 | node1 = \ |