| 86 | self.decoded = tf.contrib.layers.fully_connected(fc2, num_outputs=784, activation_fn=tf.sigmoid) |
| 87 | |
def loss(self):
    """Build the training loss tensors and the merged summary op.

    Reads ``self.v_length``, ``self.Y``, ``self.X`` and ``self.decoded``.
    Sets ``self.margin_loss``, ``self.reconstruction_err``,
    ``self.total_loss`` and ``self.merged_sum``. Returns nothing.
    """
    # 1. The margin loss

    # max(0, m_plus - ||v_c||)^2: penalty applied where the class is present.
    present_penalty = tf.square(tf.maximum(0., cfg.m_plus - self.v_length))
    # max(0, ||v_c|| - m_minus)^2: penalty applied where the class is absent.
    absent_penalty = tf.square(tf.maximum(0., self.v_length - cfg.m_minus))
    # Graph-construction-time sanity check on the static shape.
    assert present_penalty.get_shape() == [cfg.batch_size, 10, 1, 1]

    # Flatten [batch_size, 10, 1, 1] => [batch_size, 10].
    present_penalty = tf.reshape(present_penalty, shape=(cfg.batch_size, -1))
    absent_penalty = tf.reshape(absent_penalty, shape=(cfg.batch_size, -1))

    # T_c selects which penalty applies per class; the labels are used
    # directly (self.Y is expected to be [batch_size, 10] -- TODO confirm
    # against where self.Y is built).
    T_c = self.Y
    # Element-wise: present-class term plus down-weighted absent-class term.
    L_c = T_c * present_penalty + cfg.lambda_val * (1 - T_c) * absent_penalty

    # Sum over classes per sample, then average over the batch.
    self.margin_loss = tf.reduce_mean(tf.reduce_sum(L_c, axis=1))

    # 2. The reconstruction loss
    origin = tf.reshape(self.X, shape=(cfg.batch_size, -1))
    squared = tf.square(self.decoded - origin)
    # Mean squared error over all pixels of all images; kept as the value
    # that is logged under 'reconstruction_loss'.
    self.reconstruction_err = tf.reduce_mean(squared)
    # Per-image sum of squared errors, averaged over the batch. This uses the
    # same "sum per sample, mean over batch" reduction as the margin loss, so
    # the weight below is not additionally divided by the pixel count.
    reconstruction_sse = tf.reduce_mean(tf.reduce_sum(squared, axis=1))

    # 3. Total loss
    # Small weight so the reconstruction term acts as a regulariser and does
    # not dominate the margin loss.
    reconstruction_weight = 0.0005
    self.total_loss = self.margin_loss + reconstruction_weight * reconstruction_sse

    # Summary
    tf.summary.scalar('margin_loss', self.margin_loss)
    tf.summary.scalar('reconstruction_loss', self.reconstruction_err)
    tf.summary.scalar('total_loss', self.total_loss)
    # Decoder output reshaped to 28x28 single-channel images for display.
    recon_img = tf.reshape(self.decoded, shape=(cfg.batch_size, 28, 28, 1))
    tf.summary.image('reconstruction_img', recon_img)
    self.merged_sum = tf.summary.merge_all()