MCP · Copy · Index your code
hub / github.com/TheAlgorithms/Python / CNN

Class CNN

neural_network/convolution_neural_network.py:23–351  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

21
22
23class CNN:
def __init__(
    self,
    conv1_get,
    size_p1,
    bp_num1,
    bp_num2,
    bp_num3,
    rate_w=0.2,
    rate_t=0.2,
    *,
    seed=None,
):
    """
    Store the network configuration and randomly initialise all parameters.

    :param conv1_get: sequence ``[size, number, step]`` for the convolution
        layer: kernel side length, number of kernels, and stride
    :param size_p1: pooling size
    :param bp_num1: number of units in the flatten layer
    :param bp_num2: number of units in the hidden layer
    :param bp_num3: number of units in the output layer
    :param rate_w: learning rate for the weights
    :param rate_t: learning rate for the thresholds
    :param seed: optional seed forwarded to ``np.random.default_rng`` so the
        initial weights/thresholds are reproducible; ``None`` (the default)
        draws fresh OS entropy, exactly as before this parameter existed
    """
    self.num_bp1 = bp_num1
    self.num_bp2 = bp_num2
    self.num_bp3 = bp_num3
    # conv1 keeps [kernel size, kernel count]; the stride is stored separately.
    self.conv1 = conv1_get[:2]
    self.step_conv1 = conv1_get[2]
    self.size_pooling1 = size_p1
    self.rate_weight = rate_w
    self.rate_thre = rate_t
    rng = np.random.default_rng(seed)
    # Weights: rng.random is in [0, 1), so -r + 0.5 lies in (-0.5, 0.5].
    # They are wrapped with np.asmatrix as in the original code.
    # NOTE(review): other methods of the class may rely on np.matrix
    # semantics (e.g. `*` as matrix product); confirm before switching to ndarray.
    self.w_conv1 = [
        np.asmatrix(-1 * rng.random((self.conv1[0], self.conv1[0])) + 0.5)
        for _ in range(self.conv1[1])
    ]
    # wkj: (output units x hidden units); vji: (hidden units x flatten units)
    self.wkj = np.asmatrix(-1 * rng.random((self.num_bp3, self.num_bp2)) + 0.5)
    self.vji = np.asmatrix(-1 * rng.random((self.num_bp2, self.num_bp1)) + 0.5)
    # Thresholds: -2r + 1 lies in (-1, 1].
    self.thre_conv1 = -2 * rng.random(self.conv1[1]) + 1
    self.thre_bp2 = -2 * rng.random(self.num_bp2) + 1
    self.thre_bp3 = -2 * rng.random(self.num_bp3) + 1
54
def save_model(self, save_path):
    """
    Persist the network's configuration and parameters with pickle.

    Every attribute listed below is copied into a plain dict whose keys are
    the attribute names themselves. The dict is pickled to ``save_path``,
    and then a confirmation line is printed.

    :param save_path: destination file; it is opened in ``"wb"`` mode, so
        any existing file at that path is overwritten
    """
    # The order here fixes the key order of the pickled dict.
    field_names = (
        "num_bp1",
        "num_bp2",
        "num_bp3",
        "conv1",
        "step_conv1",
        "size_pooling1",
        "rate_weight",
        "rate_thre",
        "w_conv1",
        "wkj",
        "vji",
        "thre_conv1",
        "thre_bp2",
        "thre_bp3",
    )
    snapshot = {name: getattr(self, name) for name in field_names}

    with open(save_path, "wb") as handle:
        pickle.dump(snapshot, handle)

    print(f"Model saved: {save_path}")
77
78 @classmethod
79 def read_model(cls, model_path):
80 # read saved model

Callers 1

read_model · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected