| 64 | |
| 65 | |
class Net(nn.Module):
    """LeNet-style CNN: two conv -> ReLU -> max-pool stages, then three
    fully connected layers producing 10 outputs per sample.

    Layer sizes (from ``__init__``):
      conv1: 3 -> 6 channels, 5x5 kernel;  pool1: 2x2, stride 2
      conv2: 6 -> 16 channels, 5x5 kernel; pool2: 2x2, stride 2
      fc1: 400 -> 120, fc2: 120 -> 84, fc3: 84 -> 10

    ``forward`` flattens to a hard-coded ``16 * 5 * 5`` features, which is
    what the conv/pool stack yields for 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on ``x`` and return the raw fc3 outputs.

        Args:
            x: image batch; conv1 requires 3 input channels, and the flatten
               below assumes spatial size 32x32, i.e. shape (N, 3, 32, 32).

        Returns:
            Tensor of shape (N, 10). No softmax is applied here.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # NOTE: the feature count is hard-coded for 32x32 inputs. For other
        # sizes, view(-1, ...) either raises or, if the element count happens
        # to divide evenly, silently yields a wrong batch dimension.
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # Weight initialization
    def initialize_weights(self):
        """Re-initialize the parameters of every submodule in place.

        - Conv2d: Xavier-normal weights; bias zeroed if present.
        - BatchNorm2d: weight filled with 1, bias zeroed (skipped when the
          layer has no affine parameters).
        - Linear: weights drawn from N(0, 0.01**2); bias zeroed if present.

        Other module types are left untouched. Not called by ``__init__``,
        so callers must invoke it explicitly.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Linear(bias=False) has no bias; guard like the Conv2d branch.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
| 99 | |
| 100 | |
net = Net()  # Build the network instance