| 21 | |
| 22 | |
| 23 | class CNN: |
| 24 | def __init__( |
| 25 | self, conv1_get, size_p1, bp_num1, bp_num2, bp_num3, rate_w=0.2, rate_t=0.2 |
| 26 | ): |
| 27 | """ |
| 28 | :param conv1_get: [a,c,d], size, number, step of convolution kernel |
| 29 | :param size_p1: pooling size |
| 30 | :param bp_num1: units number of flatten layer |
| 31 | :param bp_num2: units number of hidden layer |
| 32 | :param bp_num3: units number of output layer |
| 33 | :param rate_w: rate of weight learning |
| 34 | :param rate_t: rate of threshold learning |
| 35 | """ |
| 36 | self.num_bp1 = bp_num1 |
| 37 | self.num_bp2 = bp_num2 |
| 38 | self.num_bp3 = bp_num3 |
| 39 | self.conv1 = conv1_get[:2] |
| 40 | self.step_conv1 = conv1_get[2] |
| 41 | self.size_pooling1 = size_p1 |
| 42 | self.rate_weight = rate_w |
| 43 | self.rate_thre = rate_t |
| 44 | rng = np.random.default_rng() |
| 45 | self.w_conv1 = [ |
| 46 | np.asmatrix(-1 * rng.random((self.conv1[0], self.conv1[0])) + 0.5) |
| 47 | for i in range(self.conv1[1]) |
| 48 | ] |
| 49 | self.wkj = np.asmatrix(-1 * rng.random((self.num_bp3, self.num_bp2)) + 0.5) |
| 50 | self.vji = np.asmatrix(-1 * rng.random((self.num_bp2, self.num_bp1)) + 0.5) |
| 51 | self.thre_conv1 = -2 * rng.random(self.conv1[1]) + 1 |
| 52 | self.thre_bp2 = -2 * rng.random(self.num_bp2) + 1 |
| 53 | self.thre_bp3 = -2 * rng.random(self.num_bp3) + 1 |
| 54 | |
| 55 | def save_model(self, save_path): |
| 56 | # save model dict with pickle |
| 57 | model_dic = { |
| 58 | "num_bp1": self.num_bp1, |
| 59 | "num_bp2": self.num_bp2, |
| 60 | "num_bp3": self.num_bp3, |
| 61 | "conv1": self.conv1, |
| 62 | "step_conv1": self.step_conv1, |
| 63 | "size_pooling1": self.size_pooling1, |
| 64 | "rate_weight": self.rate_weight, |
| 65 | "rate_thre": self.rate_thre, |
| 66 | "w_conv1": self.w_conv1, |
| 67 | "wkj": self.wkj, |
| 68 | "vji": self.vji, |
| 69 | "thre_conv1": self.thre_conv1, |
| 70 | "thre_bp2": self.thre_bp2, |
| 71 | "thre_bp3": self.thre_bp3, |
| 72 | } |
| 73 | with open(save_path, "wb") as f: |
| 74 | pickle.dump(model_dic, f) |
| 75 | |
| 76 | print(f"Model saved: {save_path}") |
| 77 | |
| 78 | @classmethod |
| 79 | def read_model(cls, model_path): |
| 80 | # read saved model |