函数作用:模型初始化 Conv1 -> Pool1 -> Conv2 -> Pool2 -> Flatten -> FC3 -> FC4 -> FC5 -> Softmax
(self)
| 857 | self.is_initialized = False |
| 858 | |
def _set_params(self):
    """
    Build the network's layers and mark the model as initialized.

    The layers are stored in ``self.layers`` (an ``OrderedDict``) in
    this order:

        Conv1 -> Pool1 -> Conv2 -> Pool2 -> Flatten -> FC3 -> FC4 -> FC5

    FC5 uses an identity affine activation (slope=1, intercept=0). No
    softmax layer is created in this method, so the softmax step named
    in the original description is presumably applied elsewhere (for
    example in the loss) -- TODO confirm.

    Conv2 is configured from the ``conv2_*`` hyperparameters when the
    instance defines them. Previously it was built from the ``conv1_*``
    values (a copy of the Conv1 block), which made it impossible to
    give the second convolution its own settings. When a ``conv2_*``
    attribute is missing, the matching ``conv1_*`` value is used, which
    reproduces the earlier behaviour.
    """

    def conv2_hparam(suffix):
        # Prefer the dedicated conv2_* setting; fall back to conv1_* so
        # instances that never defined conv2_* keep working unchanged.
        fallback = getattr(self, "conv1_" + suffix)
        return getattr(self, "conv2_" + suffix, fallback)

    self.layers = OrderedDict()
    self.layers["Conv1"] = Conv2D(
        out_ch=self.conv1_out_ch,
        kernel_shape=self.conv1_kernel_shape,
        pad=self.conv1_pad,
        stride=self.conv1_stride,
        acti_fn="sigmoid",
        optimizer=self.optimizer,
        init_w=self.init_w,
    )
    self.layers["Pool1"] = Pool2D(
        mode="max",
        optimizer=self.optimizer,
        stride=self.pool1_stride,
        kernel_shape=self.pool1_kernel_shape,
    )
    self.layers["Conv2"] = Conv2D(
        out_ch=conv2_hparam("out_ch"),
        kernel_shape=conv2_hparam("kernel_shape"),
        pad=conv2_hparam("pad"),
        stride=conv2_hparam("stride"),
        acti_fn="sigmoid",
        optimizer=self.optimizer,
        init_w=self.init_w,
    )
    self.layers["Pool2"] = Pool2D(
        mode="max",
        optimizer=self.optimizer,
        stride=self.pool2_stride,
        kernel_shape=self.pool2_kernel_shape,
    )
    self.layers["Flatten"] = Flatten(optimizer=self.optimizer)
    self.layers["FC3"] = FullyConnected(
        n_out=self.fc3_out,
        acti_fn="sigmoid",
        init_w=self.init_w,
        optimizer=self.optimizer,
    )
    self.layers["FC4"] = FullyConnected(
        n_out=self.fc4_out,
        acti_fn="sigmoid",
        init_w=self.init_w,
        optimizer=self.optimizer,
    )
    # Identity activation: FC5 emits raw scores.
    self.layers["FC5"] = FullyConnected(
        n_out=self.fc5_out,
        acti_fn="affine(slope=1, intercept=0)",
        init_w=self.init_w,
        optimizer=self.optimizer,
    )
    self.is_initialized = True
| 915 | |
| 916 | def forward(self, X_train): |
no test coverage detected