| 77 | |
@classmethod
def read_model(cls, model_path):
    """Rebuild a model instance from a pickled parameter dictionary.

    The file at ``model_path`` is opened in binary mode and unpickled into a
    dictionary. The constructor arguments are read from the keys ``conv1``,
    ``step_conv1``, ``size_pooling1``, ``num_bp1``, ``num_bp2``, ``num_bp3``,
    ``rate_weight`` and ``rate_thre``; the stored parameters ``w_conv1``,
    ``wkj``, ``vji``, ``thre_conv1``, ``thre_bp2`` and ``thre_bp3`` are then
    assigned onto the new instance.

    Args:
        model_path: Path of the pickled model file.

    Returns:
        An instance of ``cls`` carrying the stored parameters.

    Warning:
        ``pickle.load`` can execute arbitrary code while deserialising.
        Only load model files that come from a trusted source.
    """
    # SECURITY: unpickling is unsafe on untrusted data; the call is kept
    # because the on-disk format is a pickle.
    with open(model_path, "rb") as f:
        model_dic = pickle.load(f)  # noqa: S301

    # The stored conv1 list gets step_conv1 appended to form the first
    # constructor argument.
    conv_get = model_dic.get("conv1")
    conv_get.append(model_dic.get("step_conv1"))
    size_p1 = model_dic.get("size_pooling1")
    bp1 = model_dic.get("num_bp1")
    bp2 = model_dic.get("num_bp2")
    bp3 = model_dic.get("num_bp3")
    r_w = model_dic.get("rate_weight")
    r_t = model_dic.get("rate_thre")

    # Instantiate through ``cls`` (not a hard-coded class name) so that a
    # subclass calling ``read_model`` gets an instance of itself.
    conv_ins = cls(conv_get, size_p1, bp1, bp2, bp3, r_w, r_t)

    # Overwrite the freshly initialised parameters with the stored ones.
    conv_ins.w_conv1 = model_dic.get("w_conv1")
    conv_ins.wkj = model_dic.get("wkj")
    conv_ins.vji = model_dic.get("vji")
    conv_ins.thre_conv1 = model_dic.get("thre_conv1")
    conv_ins.thre_bp2 = model_dic.get("thre_bp2")
    conv_ins.thre_bp3 = model_dic.get("thre_bp3")
    return conv_ins
| 102 | |
def sig(self, x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``."""
    negated = -1 * x
    denominator = 1 + np.exp(negated)
    return 1 / denominator