MCPcopy
hub / github.com/hanzhanggit/StackGAN / CondGAN

Class CondGAN

stageI/model.py:11–224  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

9
10
11class CondGAN(object):
def __init__(self, image_shape):
    """Read GAN hyper-parameters from ``cfg`` and build the D templates.

    Args:
        image_shape: Shape of the images handled by the model. Only
            ``image_shape[0]`` is used to derive the spatial sizes
            (presumably (height, width, channels) with square images --
            TODO confirm against callers).

    Raises:
        NotImplementedError: if ``cfg.GAN.NETWORK_TYPE`` is neither
            ``"default"`` nor ``"simple"``.
    """
    self.batch_size = cfg.TRAIN.BATCH_SIZE
    self.network_type = cfg.GAN.NETWORK_TYPE
    self.image_shape = image_shape
    self.gf_dim = cfg.GAN.GF_DIM
    self.df_dim = cfg.GAN.DF_DIM
    self.ef_dim = cfg.GAN.EMBEDDING_DIM

    # Base spatial size and its successive halvings, truncated to int
    # (e.g. generator() starts from an s16 x s16 feature map).
    self.s = image_shape[0]
    self.s2, self.s4, self.s8, self.s16 = (
        int(self.s / factor) for factor in (2, 4, 8, 16))

    # The network types differ only in the discriminator's image encoder.
    # Resolve it first so an unsupported type fails before any variables
    # are created; only the selected method is looked up.
    if self.network_type == "default":
        encode_image = self.d_encode_image
    elif self.network_type == "simple":
        encode_image = self.d_encode_image_simple
    else:
        raise NotImplementedError(
            "Unsupported cfg.GAN.NETWORK_TYPE: %r" % (self.network_type,))

    # Since D is only used during training, we build a template for safe
    # reuse of the variables when computing the loss for fake/real/wrong
    # images. We do not do this for G, because batch_norm needs different
    # options for training and testing.
    with tf.variable_scope("d_net"):
        self.d_encode_img_template = encode_image()
        self.d_context_template = self.context_embedding()
        self.discriminator_template = self.discriminator()
41
42 # g-net
def generate_condition(self, c_var):
    """Project a conditioning input into ``mean`` and ``log_sigma`` halves.

    ``c_var`` is flattened, passed through a fully connected layer with
    ``2 * self.ef_dim`` outputs and then a leaky ReLU (leakiness 0.2).
    The result is split column-wise at ``self.ef_dim``.

    Args:
        c_var: Conditioning tensor (presumably the text embedding with a
            leading batch axis -- TODO confirm against callers).

    Returns:
        A two-element list ``[mean, log_sigma]``: the first ``ef_dim``
        columns and the remaining ``ef_dim`` columns respectively.
    """
    projected = pt.wrap(c_var).flatten()
    projected = projected.custom_fully_connected(self.ef_dim * 2)
    projected = projected.apply(leaky_rectify, leakiness=0.2)

    split_at = self.ef_dim
    return [projected[:, :split_at], projected[:, split_at:]]
52
53 def generator(self, z_var):
54 node1_0 =\
55 (pt.wrap(z_var).
56 flatten().
57 custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
58 fc_batch_norm().
59 reshape([-1, self.s16, self.s16, self.gf_dim * 8]))
60 node1_1 = \
61 (node1_0.
62 custom_conv2d(self.gf_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
63 conv_batch_norm().
64 apply(tf.nn.relu).
65 custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
66 conv_batch_norm().
67 apply(tf.nn.relu).
68 custom_conv2d(self.gf_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).

Callers 3

build_modelFunction · 0.90
build_modelFunction · 0.90
run_exp.pyFile · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected