| 87 | import torch.nn.functional as F |
| 88 | |
class Net(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` features per sample. That is the feature map
    produced by a 3-channel 32x32 input:
    32 -> conv(5) -> 28 -> pool -> 14 -> conv(5) -> 10 -> pool -> 5.

    Args:
        num_classes: Number of output scores produced by ``fc3`` (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw (un-softmaxed) class scores for ``x``.

        Args:
            x: Image batch of shape (N, 3, H, W), or a single unbatched image
               (3, H, W). H and W must reduce to 5x5 after the two conv/pool
               stages (e.g. 32x32); otherwise ``fc1`` raises.

        Returns:
            Tensor of shape (N, num_classes), or (1, num_classes) for unbatched
            input (matching the previous ``view(-1, 400)`` behaviour).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten only the trailing (C, H, W) dims so the batch axis is never
        # mixed. The former ``view(-1, 400)`` could silently regroup samples
        # when the spatial size was wrong; now a mismatch errors in fc1.
        x = x.flatten(-3)
        if x.dim() == 1:
            # Unbatched input: keep the (1, features) shape the old view produced.
            x = x.unsqueeze(0)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
| 107 | |
net = Net()
# Trainable parameters only. Materialised as a list because ``filter`` returns a
# one-shot iterator: any second pass over it (e.g. counting, then handing it to
# an optimizer) would silently see no parameters.
parameters = [p for p in net.parameters() if p.requires_grad]