(x, T, D, batch_sz)
| 27 | |
| 28 | |
def x2sequence(x, T, D, batch_sz):
    """Convert a batch-major 3-D tensor into a list of per-time-step tensors.

    Args:
        x: Rank-3 input tensor, expected as (batch_sz, T, D) -- the axis
            swap below moves the time axis to the front.
        T: Number of time steps; also the length of the returned list.
        D: Feature size of each time step.
        batch_sz: Batch size used when flattening the swapped tensor.

    Returns:
        A list of ``T`` tensors, each of shape (batch_sz, D), ordered by
        time step.
    """
    # Swap batch and time axes: (batch_sz, T, D) -> (T, batch_sz, D).
    x = tf.transpose(x, (1, 0, 2))
    # Flatten to (T * batch_sz, D); rows are now grouped by time step.
    x = tf.reshape(x, (T * batch_sz, D))
    # Cut into T equal row blocks so block t holds time step t.
    # TF >= 1.0 argument order: tf.split(value, num_or_size_splits, axis).
    return tf.split(x, T, axis=0)
| 39 | |
| 40 | class SimpleRNN: |
| 41 | def __init__(self, M): |