()
| 19 | |
| 20 | |
def example_0(*, n_iters=100, lr=0.1):
    """Train an LSTM to reproduce a short fixed target sequence from random inputs.

    Each iteration does the following:
    - feeds one fixed random input vector per target value into the network;
    - prints the first component of each node's hidden state ``h``;
    - computes the loss against the targets with ``ToyLossLayer``;
    - applies a parameter update and clears the inputs for the next pass.

    Args:
        n_iters: Number of training iterations (default 100, matching the
            original hard-coded value).
        lr: Learning rate passed to ``lstm_param.apply_diff`` (default 0.1,
            matching the original hard-coded value).
    """
    # Seed NumPy's *global* RNG so the random inputs are reproducible. If
    # LstmParam draws its initial weights from the global RNG too (not visible
    # here), those become reproducible as well.
    np.random.seed(0)

    # Network dimensions: number of LSTM memory cells and input vector length.
    mem_cell_ct = 100
    x_dim = 50
    lstm_param = LstmParam(mem_cell_ct, x_dim)
    lstm_net = LstmNetwork(lstm_param)

    # Target sequence, plus one fixed random input vector per target element.
    y_list = [-0.5, 0.2, 0.1, -0.5]
    input_val_arr = [np.random.random(x_dim) for _ in y_list]
    n_steps = len(y_list)

    for cur_iter in range(n_iters):
        print("iter", "%2s" % str(cur_iter), end=": ")

        # Feed the whole input sequence, one vector per time step.
        for x in input_val_arr:
            lstm_net.x_list_add(x)

        # Report h[0] of each of the first n_steps nodes as the prediction.
        pred_strs = [
            "% 2.5f" % lstm_net.lstm_node_list[step].state.h[0]
            for step in range(n_steps)
        ]
        print("y_pred = [" + ", ".join(pred_strs) + "]", end=", ")

        # y_list_is returns the loss value printed below. Presumably it also
        # accumulates the gradients consumed by apply_diff (defined elsewhere).
        loss = lstm_net.y_list_is(y_list, ToyLossLayer)
        print("loss:", "%.3e" % loss)
        lstm_param.apply_diff(lr=lr)

        # Drop this pass's inputs so the next iteration starts a fresh sequence.
        lstm_net.x_list_clear()
| 47 | |
| 48 | if __name__ == "__main__": |
no test coverage detected