The routing algorithm for one capsule in the layer l+1. Args: input: A Tensor with [batch_size, num_caps_l=1152, length(u_i)=8, 1] shape, num_caps_l meaning the number of capsules in the layer l. Returns: A Tensor of shape [batch_size, 1, length(v_j)=16, 1] representing the vector output `v_j` of capsule j in the layer l+1.
(input, b_IJ, idx_j)
| 70 | |
| 71 | |
| 72 | def capsule(input, b_IJ, idx_j): |
| 73 | ''' The routing algorithm for one capsule in the layer l+1. |
| 74 | |
| 75 | Args: |
| 76 | input: A Tensor with [batch_size, num_caps_l=1152, length(u_i)=8, 1] |
| 77 | shape, num_caps_l meaning the number of capsule in the layer l. |
| 78 | Returns: |
| 79 | A Tensor of shape [batch_size, 1, length(v_j)=16, 1] representing the |
| 80 | vector output `v_j` of capsule j in the layer l+1 |
| 81 | Notes: |
| 82 | u_i represents the vector output of capsule i in the layer l, and |
| 83 | v_j the vector output of capsule j in the layer l+1. |
| 84 | ''' |
| 85 | |
| 86 | with tf.variable_scope('routing'): |
| 87 | w_initializer = np.random.normal(size=[1, 1152, 8, 16], scale=0.01) |
| 88 | W_Ij = tf.Variable(w_initializer, dtype=tf.float32) |
| 89 | # repeat W_Ij with batch_size times to shape [batch_size, 1152, 8, 16] |
| 90 | W_Ij = tf.tile(W_Ij, [cfg.batch_size, 1, 1, 1]) |
| 91 | |
| 92 | # calc u_hat |
| 93 | # [8, 16].T x [8, 1] => [16, 1] => [batch_size, 1152, 16, 1] |
| 94 | u_hat = tf.matmul(W_Ij, input, transpose_a=True) |
| 95 | assert u_hat.get_shape() == [cfg.batch_size, 1152, 16, 1] |
| 96 | |
| 97 | shape = b_IJ.get_shape().as_list() |
| 98 | size_splits = [idx_j, 1, shape[2] - idx_j - 1] |
| 99 | for r_iter in range(cfg.iter_routing): |
| 100 | # line 4: |
| 101 | # [1, 1152, 10, 1] |
| 102 | c_IJ = tf.nn.softmax(b_IJ, dim=2) |
| 103 | assert c_IJ.get_shape() == [1, 1152, 10, 1] |
| 104 | |
| 105 | # line 5: |
| 106 | # weighting u_hat with c_I in the third dim, |
| 107 | # then sum in the second dim, resulting in [batch_size, 1, 16, 1] |
| 108 | b_Il, b_Ij, b_Ir = tf.split(b_IJ, size_splits, axis=2) |
| 109 | c_Il, c_Ij, b_Ir = tf.split(c_IJ, size_splits, axis=2) |
| 110 | assert c_Ij.get_shape() == [1, 1152, 1, 1] |
| 111 | |
| 112 | s_j = tf.multiply(c_Ij, u_hat) |
| 113 | s_j = tf.reduce_sum(tf.multiply(c_Ij, u_hat), |
| 114 | axis=1, keep_dims=True) |
| 115 | assert s_j.get_shape() == [cfg.batch_size, 1, 16, 1] |
| 116 | |
| 117 | # line 6: |
| 118 | # squash using Eq.1, resulting in [batch_size, 1, 16, 1] |
| 119 | v_j = squash(s_j) |
| 120 | assert s_j.get_shape() == [cfg.batch_size, 1, 16, 1] |
| 121 | |
| 122 | # line 7: |
| 123 | # tile v_j from [batch_size ,1, 16, 1] to [batch_size, 1152, 16, 1] |
| 124 | # [16, 1].T x [16, 1] => [1, 1], then reduce mean in the |
| 125 | # batch_size dim, resulting in [1, 1152, 1, 1] |
| 126 | v_j_tiled = tf.tile(v_j, [1, 1152, 1, 1]) |
| 127 | u_produce_v = tf.matmul(u_hat, v_j_tiled, transpose_a=True) |
| 128 | assert u_produce_v.get_shape() == [cfg.batch_size, 1152, 1, 1] |
| 129 | b_Ij += tf.reduce_sum(u_produce_v, axis=0, keep_dims=True) |