| 9 | |
| 10 | |
| 11 | class CondGAN(object): |
def __init__(self, image_shape):
    """Read GAN hyper-parameters from ``cfg`` and build the D templates.

    Args:
        image_shape: Image shape; only ``image_shape[0]`` is read here and
            used as the side length ``s``.
            NOTE(review): presumably (height, width, channels) with
            height == width — confirm against callers.

    Raises:
        NotImplementedError: if ``cfg.GAN.NETWORK_TYPE`` is neither
            ``"default"`` nor ``"simple"``.
    """
    self.batch_size = cfg.TRAIN.BATCH_SIZE
    self.network_type = cfg.GAN.NETWORK_TYPE
    self.image_shape = image_shape
    self.gf_dim = cfg.GAN.GF_DIM
    self.df_dim = cfg.GAN.DF_DIM
    self.ef_dim = cfg.GAN.EMBEDDING_DIM

    # Side length of the full image and of its 1/2, 1/4, 1/8, 1/16 scales.
    self.s = image_shape[0]
    self.s2, self.s4, self.s8, self.s16 = (
        int(self.s / 2), int(self.s / 4), int(self.s / 8), int(self.s / 16))

    # Since D is only used during training, we build its sub-networks once
    # as templates so the variables are safely reused when computing the
    # loss for fake/real/wrong images. We do not do this for G, because
    # batch_norm needs different options for training and testing.
    # The two network types differ only in the image-encoder builder.
    if self.network_type == "default":
        build_image_encoder = self.d_encode_image
    elif self.network_type == "simple":
        build_image_encoder = self.d_encode_image_simple
    else:
        raise NotImplementedError(
            "Unsupported cfg.GAN.NETWORK_TYPE: %r" % (self.network_type,))

    with tf.variable_scope("d_net"):
        self.d_encode_img_template = build_image_encoder()
        self.d_context_template = self.context_embedding()
        self.discriminator_template = self.discriminator()
| 41 | |
| 42 | # g-net |
def generate_condition(self, c_var):
    """Split a projection of ``c_var`` into mean and log-sigma halves.

    ``c_var`` is flattened, passed through a fully connected layer with
    ``2 * ef_dim`` outputs, then through ``leaky_rectify`` with
    leakiness 0.2. The first ``ef_dim`` columns form the mean; the
    remaining columns form the log-sigma.

    Returns:
        list: ``[mean, log_sigma]``.
    """
    half = self.ef_dim
    hidden = pt.wrap(c_var).flatten()
    hidden = hidden.custom_fully_connected(half * 2)
    hidden = hidden.apply(leaky_rectify, leakiness=0.2)
    # NOTE(review): the leaky non-linearity is applied to both halves, so
    # log_sigma also passes through it — confirm this is intended.
    return [hidden[:, :half], hidden[:, half:]]
| 52 | |
| 53 | def generator(self, z_var): |
| 54 | node1_0 =\ |
| 55 | (pt.wrap(z_var). |
| 56 | flatten(). |
| 57 | custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8). |
| 58 | fc_batch_norm(). |
| 59 | reshape([-1, self.s16, self.s16, self.gf_dim * 8])) |
| 60 | node1_1 = \ |
| 61 | (node1_0. |
| 62 | custom_conv2d(self.gf_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1). |
| 63 | conv_batch_norm(). |
| 64 | apply(tf.nn.relu). |
| 65 | custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1). |
| 66 | conv_batch_norm(). |
| 67 | apply(tf.nn.relu). |
| 68 | custom_conv2d(self.gf_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1). |
no outgoing calls
no test coverage detected