MCPcopy Index your code
hub / github.com/lazyprogrammer/machine_learning_examples / main

Function main

ann_class2/momentum.py:22–198  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

20
21
22def main():
23 # compare 3 scenarios:
24 # 1. batch SGD
25 # 2. batch SGD with momentum
26 # 3. batch SGD with Nesterov momentum
27
28 max_iter = 20 # make it 30 for sigmoid
29 print_period = 50
30
31 Xtrain, Xtest, Ytrain, Ytest = get_normalized_data()
32 lr = 0.00004
33 reg = 0.01
34
35 Ytrain_ind = y2indicator(Ytrain)
36 Ytest_ind = y2indicator(Ytest)
37
38 N, D = Xtrain.shape
39 batch_sz = 500
40 n_batches = N // batch_sz
41
42 M = 300
43 K = 10
44 W1 = np.random.randn(D, M) / np.sqrt(D)
45 b1 = np.zeros(M)
46 W2 = np.random.randn(M, K) / np.sqrt(M)
47 b2 = np.zeros(K)
48
49 # save initial weights
50 W1_0 = W1.copy()
51 b1_0 = b1.copy()
52 W2_0 = W2.copy()
53 b2_0 = b2.copy()
54
55 # 1. batch
56 losses_batch = []
57 errors_batch = []
58 for i in range(max_iter):
59 Xtrain, Ytrain, Ytrain_ind = shuffle(Xtrain, Ytrain, Ytrain_ind)
60 for j in range(n_batches):
61 Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
62 Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
63 pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
64 # print "first batch cost:", cost(pYbatch, Ybatch)
65
66 # gradients
67 gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
68 gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
69 gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
70 gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
71
72 # updates
73 W2 -= lr*gW2
74 b2 -= lr*gb2
75 W1 -= lr*gW1
76 b1 -= lr*gb1
77
78 if j % print_period == 0:
79 pY, _ = forward(Xtest, W1, b1, W2, b2)

Callers 1

momentum.py · File · 0.70

Calls 10

get_normalized_data · Function · 0.90
y2indicator · Function · 0.90
forward · Function · 0.90
derivative_w2 · Function · 0.90
derivative_b2 · Function · 0.90
derivative_w1 · Function · 0.90
derivative_b1 · Function · 0.90
cost · Function · 0.90
error_rate · Function · 0.90
copy · Method · 0.45

Tested by

no test coverage detected