`batch_flatten(x)`: Flatten the tensor except the first dimension.
| 13 | |
| 14 | |
def batch_flatten(x):
    """
    Flatten the tensor except the first dimension.

    Args:
        x: a ``tf.Tensor``; its first dimension is kept as-is and all
            remaining dimensions are collapsed into one.

    Returns:
        A 2-D tensor of shape ``[x.shape[0], prod(x.shape[1:])]``.
        When every non-first dimension is statically known, the target
        shape is built from static integers; otherwise (including when the
        rank of ``x`` itself is statically unknown) the reshape uses the
        runtime shape of ``x``.
    """
    static_shape = x.get_shape()
    # ``TensorShape.as_list()`` raises ValueError for an unknown rank, so the
    # static path is only attempted when the rank is known.
    if static_shape.ndims is not None:
        rest = static_shape.as_list()[1:]
        if None not in rest:
            # All trailing dims are known: emit a fully static target shape.
            return tf.reshape(x, [-1, int(np.prod(rest))])
    # Some trailing dim (or the rank) is unknown: keep the runtime first
    # dimension and let reshape infer the flattened size.
    return tf.reshape(x, tf.stack([tf.shape(x)[0], -1]))
| 23 | |
| 24 | |
| 25 | @layer_register(log_shape=True) |