| 29 | |
| 30 | |
class SampleConvNet(nn.Module):
    """Small convolutional network: two conv layers followed by two linear layers.

    Expects inputs of shape ``[B, 3, 28, 28]``. The fan-in of ``fc1``
    (``32 * 4 * 4``) is derived from that spatial size, so inputs of any
    other spatial size are rejected with a shape error. Returns the raw
    ``fc2`` output (no final activation) of shape ``[B, 10]``.
    """

    def __init__(self):
        super().__init__()
        # Positional args: (in_channels, out_channels, kernel_size, stride).
        self.conv1 = nn.Conv2d(3, 16, 8, 2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, 4, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: input tensor of shape ``[B, 3, 28, 28]``.

        Returns:
            Tensor of shape ``[B, 10]``.
        """
        x = F.relu(self.conv1(x))  # -> [B, 16, 14, 14]
        x = F.max_pool2d(x, 2, 1)  # -> [B, 16, 13, 13]
        x = F.relu(self.conv2(x))  # -> [B, 32, 5, 5]
        x = F.max_pool2d(x, 2, 1)  # -> [B, 32, 4, 4]
        # Flatten per sample, keeping the batch dimension fixed.
        # ``view(-1, 512)`` would infer the batch size. A mis-sized input
        # would then be silently folded into extra "samples", mixing data
        # across the batch axis. With flatten, fc1 raises a shape error.
        x = x.flatten(start_dim=1)  # -> [B, 512]
        x = F.relu(self.fc1(x))  # -> [B, 32]
        x = self.fc2(x)  # -> [B, 10]
        return x

    def name(self):
        """Return the model's display name, ``"SampleConvNet"``."""
        return "SampleConvNet"
| 52 | |
| 53 | |
| 54 | class GradSampleHooksTest(unittest.TestCase): |
no outgoing calls