squash(vector): Squashing function. Args: vector, a 4-D tensor with shape [batch_size, num_caps, vec_len, 1]. Returns: a 4-D tensor with the same shape as vector, squashed along the 3rd and 4th dimensions (each capsule's [vec_len, 1] vector is rescaled independently).
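Written out, the function is meant to apply the following to every capsule vector $\mathbf{s}$ (one [vec_len, 1] slice of vector). Short vectors shrink toward zero length, long vectors approach unit length, and the direction of $\mathbf{s}$ is preserved:

```latex
\mathbf{v} = \frac{\lVert \mathbf{s} \rVert^{2}}{1 + \lVert \mathbf{s} \rVert^{2}} \cdot \frac{\mathbf{s}}{\lVert \mathbf{s} \rVert},
\qquad
\lVert \mathbf{v} \rVert = \frac{\lVert \mathbf{s} \rVert^{2}}{1 + \lVert \mathbf{s} \rVert^{2}} \in [0, 1)
```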
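import tensorflow as tf


def squash(vector):
    '''Squashing function.
    Args:
        vector: A 4-D tensor with shape [batch_size, num_caps, vec_len, 1].
    Returns:
        A 4-D tensor with the same shape as vector,
        squashed along the 3rd and 4th dimensions.
    '''
    epsilon = 1e-9  # keeps an all-zero capsule from producing 0/0 = NaN
    # Squared length of each capsule, reduced over the 3rd and 4th dimensions only.
    # keepdims=True yields shape [batch_size, num_caps, 1, 1], which broadcasts over vector.
    vec_squared_norm = tf.reduce_sum(tf.square(vector), axis=[-2, -1], keepdims=True)
    # Scale factor ||s||^2 / (1 + ||s||^2), one value per capsule, in [0, 1).
    scalar_factor = vec_squared_norm / (1 + vec_squared_norm)
    # Unit vector s / ||s|| for each capsule.
    unit_vector = vector / tf.sqrt(vec_squared_norm + epsilon)
    return scalar_factor * unit_vector  # element-wise, broadcast per capsule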
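The reduction axis is the crux: reducing over the whole tensor gives one norm shared by every capsule in the batch, so capsules are no longer squashed independently. The plain-Python sketch below (no TensorFlow; the helper names squash_capsules, squash_global_norm and length are hypothetical, chosen for illustration) contrasts a per-capsule squash with a single global norm on two capsules of length 5.0 and 0.5.

```python
import math
from typing import List

Vector = List[float]


def squash_capsules(capsules: List[Vector], epsilon: float = 1e-9) -> List[Vector]:
    """Squash each capsule independently: v = ||s||^2/(1+||s||^2) * s/||s||."""
    squashed = []
    for s in capsules:
        # Squared L2 norm of this capsule only.
        sq_norm = sum(x * x for x in s)
        scale = sq_norm / (1.0 + sq_norm)
        # epsilon (an assumed small constant) avoids 0/0 for an all-zero capsule.
        norm = math.sqrt(sq_norm + epsilon)
        squashed.append([scale * x / norm for x in s])
    return squashed


def squash_global_norm(capsules: List[Vector]) -> List[Vector]:
    """Squash using ONE norm computed over every element of every capsule."""
    sq_norm = sum(x * x for s in capsules for x in s)
    scale = sq_norm / (1.0 + sq_norm)
    norm = math.sqrt(sq_norm)
    return [[scale * x / norm for x in s] for s in capsules]


def length(v: Vector) -> float:
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in v))


if __name__ == "__main__":
    caps = [[3.0, 4.0], [0.3, 0.4]]  # lengths 5.0 and 0.5
    print([round(length(v), 4) for v in squash_capsules(caps)])  # [0.9615, 0.2]
    print([round(length(v), 4) for v in squash_global_norm(caps)])
```

With per-capsule norms, each output length depends only on its own input length (25/26 and 0.25/1.25). With the global norm, the short capsule's output length is dragged far below 0.2 because the long capsule dominates the shared norm.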