MCP · copy
hub / github.com/hanzhanggit/StackGAN / CondGAN

Class CondGAN

stageII/model.py:14–321  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

12
13
14class CondGAN(object):
def __init__(self, lr_imsize, hr_lr_ratio):
    """Read GAN hyper-parameters from ``cfg`` and build the discriminator templates.

    Two parallel sets of templates are created. One set is built under the
    ``d_net`` variable scope and the other under ``hr_d_net``. Each set comes
    from ``context_embedding``, an image encoder and ``discriminator``.

    Args:
        lr_imsize: Side length of the low-resolution image. It is stored as
            ``self.s``, and ``s2``/``s4``/``s8``/``s16`` are derived from it.
        hr_lr_ratio: Stored unchanged as ``self.hr_lr_ratio``. Presumably
            this is the high-res / low-res size ratio (inferred from the
            name; not used in this method).

    Raises:
        NotImplementedError: If ``cfg.GAN.NETWORK_TYPE`` is anything other
            than ``"default"``.
    """
    self.batch_size = cfg.TRAIN.BATCH_SIZE
    self.network_type = cfg.GAN.NETWORK_TYPE
    self.hr_lr_ratio = hr_lr_ratio
    self.gf_dim = cfg.GAN.GF_DIM
    self.df_dim = cfg.GAN.DF_DIM
    self.ef_dim = cfg.GAN.EMBEDDING_DIM

    self.s = lr_imsize
    print('lr_imsize: ', lr_imsize)
    # Down-sampled sizes s/2, s/4, s/8, s/16, truncated to int.
    self.s2, self.s4, self.s8, self.s16 = \
        int(self.s / 2), int(self.s / 4), int(self.s / 8), int(self.s / 16)

    # Check the attribute just stored so the test and self.network_type
    # can never disagree; name the offending value in the error.
    if self.network_type != "default":
        raise NotImplementedError(
            'Unsupported cfg.GAN.NETWORK_TYPE: %r' % (self.network_type,))

    with tf.variable_scope("d_net"):
        self.d_context_template = self.context_embedding()
        self.d_image_template = self.d_encode_image()
        self.d_discriminator_template = self.discriminator()

    with tf.variable_scope("hr_d_net"):
        self.hr_d_context_template = self.context_embedding()
        self.hr_d_image_template = self.hr_d_encode_image()
        self.hr_discriminator_template = self.discriminator()
39
def generate_condition(self, c_var):
    """Build the conditioning-augmentation structure for a text embedding.

    Both g and hr_g use this structure, but each builds its own copy, so
    the two do not share these parameters (per the original authors' note).

    The structure flattens ``c_var`` and applies a fully connected layer
    with ``2 * self.ef_dim`` outputs. It then applies ``leaky_rectify``
    with leakiness 0.2 and splits the result on its second index.

    Returns:
        A two-element list ``[mean, log_sigma]``. ``mean`` is the first
        ``self.ef_dim`` columns of the activated output and ``log_sigma``
        is the remaining columns.
    """
    half = self.ef_dim
    projected = (pt.wrap(c_var)
                 .flatten()
                 .custom_fully_connected(half * 2))
    activated = projected.apply(leaky_rectify, leakiness=0.2)
    mean, log_sigma = activated[:, :half], activated[:, half:]
    return [mean, log_sigma]
52
53 # stage I generator (g)
54 def generator(self, z_var):
55 node1_0 =\
56 (pt.wrap(z_var).
57 flatten().
58 custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
59 fc_batch_norm().
60 reshape([-1, self.s16, self.s16, self.gf_dim * 8]))
61 node1_1 = \
62 (node1_0.
63 custom_conv2d(self.gf_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
64 conv_batch_norm().
65 apply(tf.nn.relu).
66 custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
67 conv_batch_norm().
68 apply(tf.nn.relu).
69 custom_conv2d(self.gf_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
70 conv_batch_norm())
71 node1 = \

Callers 1

run_exp.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected