Repository: github.com/nicodjimenez/lstm

Function example_0

test.py:21–45  ·  signature: example_0()


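```python
import numpy as np

from lstm import LstmParam, LstmNetwork  # assumed: model classes from the repository's lstm module

# ToyLossLayer is assumed to be defined earlier in test.py (not shown in this excerpt)


def example_0():
    # learns to repeat a simple sequence from random inputs
    np.random.seed(0)

    # parameters for input data dimension and LSTM memory cell count
    mem_cell_ct = 100
    x_dim = 50
    lstm_param = LstmParam(mem_cell_ct, x_dim)
    lstm_net = LstmNetwork(lstm_param)
    y_list = [-0.5, 0.2, 0.1, -0.5]  # target output at each time step
    input_val_arr = [np.random.random(x_dim) for _ in y_list]  # one random input vector per target

    for cur_iter in range(100):
        print("iter", "%2s" % str(cur_iter), end=": ")
        # forward pass: feed the input sequence one time step at a time
        for ind in range(len(y_list)):
            lstm_net.x_list_add(input_val_arr[ind])

        # the prediction at each step is the first element of that node's hidden state h
        print("y_pred = [" +
              ", ".join(["% 2.5f" % lstm_net.lstm_node_list[ind].state.h[0] for ind in range(len(y_list))]) +
              "]", end=", ")

        # compute the loss against the targets (this also accumulates the gradients used below)
        loss = lstm_net.y_list_is(y_list, ToyLossLayer)
        print("loss:", "%.3e" % loss)
        # gradient-descent update with learning rate 0.1, then reset the input sequence
        lstm_param.apply_diff(lr=0.1)
        lstm_net.x_list_clear()


if __name__ == "__main__":
    example_0()
```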
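The listing passes `ToyLossLayer` to `y_list_is` but does not show its definition. Because the example treats `state.h[0]` as each step's prediction, a plausible reading is a squared-error loss on the first element of the hidden state. The sketch below assumes that reading; the method names `loss` and `bottom_diff` are assumptions about the interface `y_list_is` expects, not something this page confirms.

```python
import numpy as np


class ToyLossLayer:
    """Sketch (assumed interface): squared loss on the first hidden-state element."""

    @classmethod
    def loss(cls, pred, label):
        # pred is one node's hidden state h; only h[0] is compared with the scalar label
        return (pred[0] - label) ** 2

    @classmethod
    def bottom_diff(cls, pred, label):
        # gradient of the loss with respect to pred: nonzero only at index 0
        diff = np.zeros_like(pred)
        diff[0] = 2 * (pred[0] - label)
        return diff


h = np.array([0.3, -0.1, 0.7])            # a toy hidden state
print(ToyLossLayer.loss(h, -0.5))         # ≈ 0.64
print(ToyLossLayer.bottom_diff(h, -0.5))  # ≈ [1.6, 0, 0]
```

Restricting the loss to `h[0]` is what lets a 100-cell network emit one scalar per step: the other cells act as hidden memory and only receive gradient through the recurrence.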

Callers (1)

test.py · file · 0.85

Calls (6)

x_list_add · method · 0.95
y_list_is · method · 0.95
apply_diff · method · 0.95
x_list_clear · method · 0.95
LstmParam · class · 0.90
LstmNetwork · class · 0.90
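Each `x_list_add` call appends one input vector, and the example then reads a hidden state `h` back from each node, so each call presumably drives one LSTM cell step. The sketch below shows a standard textbook LSTM cell step in NumPy at the example's sizes (`mem_cell_ct = 100`, `x_dim = 50`). The helper `lstm_step`, the weight layout (each gate matrix acting on the concatenation of `x` and the previous `h`), and the initialisation range are hypothetical; the repository's `LstmParam` and `LstmNetwork` may use a different formulation.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h_prev, s_prev, W, b):
    """One standard LSTM cell step (hypothetical helper, not the repository's API).

    W maps gate name -> (mem_cell_ct, x_dim + mem_cell_ct) weight matrix,
    b maps gate name -> (mem_cell_ct,) bias vector.
    """
    xc = np.concatenate([x, h_prev])     # input concatenated with previous hidden state
    g = np.tanh(W["g"] @ xc + b["g"])    # candidate cell values
    i = sigmoid(W["i"] @ xc + b["i"])    # input gate
    f = sigmoid(W["f"] @ xc + b["f"])    # forget gate
    o = sigmoid(W["o"] @ xc + b["o"])    # output gate
    s = g * i + s_prev * f               # new cell state
    h = np.tanh(s) * o                   # new hidden state
    return h, s


mem_cell_ct, x_dim = 100, 50
rng = np.random.default_rng(0)
W = {k: rng.uniform(-0.1, 0.1, (mem_cell_ct, x_dim + mem_cell_ct)) for k in "gifo"}
b = {k: np.zeros(mem_cell_ct) for k in "gifo"}

h = np.zeros(mem_cell_ct)
s = np.zeros(mem_cell_ct)
for x in [rng.random(x_dim) for _ in range(4)]:  # four inputs, matching len(y_list) in example_0
    h, s = lstm_step(x, h, s, W, b)
print(h[0])  # analogous to the per-step value example_0 prints as y_pred
```

With this layout the parameter count grows with `mem_cell_ct * (x_dim + mem_cell_ct)` per gate, which is why both sizes are passed to `LstmParam` in the example.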

Tested by

no test coverage detected
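With no tests detected, one common check for hand-written backpropagation (such as the gradients that `y_list_is` accumulates before `apply_diff`) is a central finite-difference comparison. The sketch below is self-contained and hypothetical: it checks the analytic gradient of a squared loss on the first output of a single `tanh` layer, standing in for the LSTM, rather than calling the repository's classes. The helper name `numerical_grad` is illustrative only.

```python
import numpy as np


def numerical_grad(f, w, eps=1e-6):
    """Central finite-difference gradient of scalar f at w (same shape as w)."""
    grad = np.zeros_like(w)
    for idx in np.ndindex(w.shape):
        orig = w[idx]
        w[idx] = orig + eps
        f_plus = f(w)
        w[idx] = orig - eps
        f_minus = f(w)
        w[idx] = orig  # restore the perturbed entry
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
x = rng.random(5)
y = -0.5
W = rng.uniform(-0.5, 0.5, (3, 5))


def loss(weights):
    # squared error on the first output of a tanh layer (stand-in for h[0])
    h = np.tanh(weights @ x)
    return (h[0] - y) ** 2


# analytic gradient: only row 0 of W influences h[0]
h0 = np.tanh(W[0] @ x)
analytic = np.zeros_like(W)
analytic[0] = 2 * (h0 - y) * (1 - h0 ** 2) * x

numeric = numerical_grad(loss, W)
print(np.max(np.abs(analytic - numeric)))  # should be close to zero
```

The same pattern applies to an LSTM: perturb one weight, rerun the forward pass and loss, and compare against the accumulated gradient for that weight.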