MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / build_graph

Method build_graph

examples/GAN/Improved-WGAN.py:45–79  ·  view source on GitHub ↗
(self, image_pos)

Source from the content-addressed store, hash-verified

43 return tf.reshape(l, [-1])
44
def build_graph(self, image_pos):
    """Build the generator/discriminator graph and the WGAN losses.

    The real images are rescaled with ``x / 128.0 - 1``, a latent batch is
    fed through ``self.generator``, and ``self.discriminator`` is applied to
    the real images, the generated images and a random interpolation of the
    two. The results are stored on ``self.d_loss`` (Wasserstein loss plus
    10x the gradient penalty) and ``self.g_loss``; nothing is returned.

    Args:
        image_pos: tensor of real images. Presumably 4-D (it is combined
            with a ``[batch, 1, 1, 1]`` tensor and gradients w.r.t. the
            interpolation are reduced over axes ``[1, 2, 3]``) -- confirm
            shape and value range against the caller.
    """
    # NOTE(review): the scope nesting below was reconstructed from a listing
    # whose indentation was lost -- confirm against the upstream file.
    real = image_pos / 128.0 - 1

    # Latent input: a random normal sample by default, overridable through
    # the placeholder named 'z'.
    z_sample = tf.random_normal([self.batch, self.zdim], name='z_train')
    z = tf.placeholder_with_default(z_sample, [None, self.zdim], name='z')

    weight_init = tf.truncated_normal_initializer(stddev=0.02)
    with argscope([Conv2D, Conv2DTranspose, FullyConnected],
                  kernel_initializer=weight_init):
        with tf.variable_scope('gen'):
            fake = self.generator(z)
        tf.summary.image('generated-samples', fake, max_outputs=30)

        # One mixing coefficient per sample, drawn uniformly from [0, 1).
        mix = tf.random_uniform(shape=[self.batch, 1, 1, 1],
                                minval=0., maxval=1., name='alpha')
        blended = real + mix * (fake - real)

        with tf.variable_scope('discrim'):
            score_real = self.discriminator(real)
            score_fake = self.discriminator(fake)
            score_blended = self.discriminator(blended)

    # The Wasserstein-GAN losses.
    self.d_loss = tf.reduce_mean(score_fake - score_real, name='d_loss')
    self.g_loss = tf.negative(tf.reduce_mean(score_fake), name='g_loss')

    # The gradient penalty: per-sample L2 norm of d(score)/d(blended),
    # penalized by its squared distance from 1.
    grad = tf.gradients(score_blended, [blended])[0]
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), [1, 2, 3]))
    gradients_rms = tf.sqrt(tf.reduce_mean(tf.square(grad_norm)), name='gradient_rms')
    gradient_penalty = tf.reduce_mean(tf.square(grad_norm - 1), name='gradient_penalty')
    # Summaries are registered before the penalty is folded into d_loss, so
    # the tensor summarized as 'd_loss' is the un-penalized Wasserstein term.
    add_moving_summary(self.d_loss, self.g_loss, gradient_penalty, gradients_rms)

    self.d_loss = tf.add(self.d_loss, 10 * gradient_penalty)

    self.collect_variables()
80
81 def optimizer(self):
82 opt = tf.train.AdamOptimizer(1e-4, beta1=0.5, beta2=0.9)

Callers

nothing calls this directly

Calls 6

generator — Method · 0.95
discriminator — Method · 0.95
add_moving_summary — Function · 0.90
argscope — Function · 0.85
add — Method · 0.45
collect_variables — Method · 0.45

Tested by

no test coverage detected