MCPcopy
hub / github.com/meta-pytorch/opacus / SampleConvNet

Class SampleConvNet

opacus/tests/grad_sample_hooks_test.py:31–51  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

29
30
class SampleConvNet(nn.Module):
    """Small two-conv / two-linear CNN used as a fixture in the grad-sample hook tests.

    Expects input of shape ``[B, 3, 28, 28]`` and returns logits of shape
    ``[B, 10]``. The layer sizes (in particular ``fc1``'s ``32 * 4 * 4`` input
    features) are tied to that 28x28 spatial size.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 8, 2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, 4, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        """Map a ``[B, 3, 28, 28]`` batch to ``[B, 10]`` logits."""
        # x of shape [B, 3, 28, 28]
        x = F.relu(self.conv1(x))  # -> [B, 16, 14, 14]
        x = F.max_pool2d(x, 2, 1)  # -> [B, 16, 13, 13]
        x = F.relu(self.conv2(x))  # -> [B, 32, 5, 5]
        x = F.max_pool2d(x, 2, 1)  # -> [B, 32, 4, 4]
        # Flatten per sample while keeping the batch axis fixed. Inferring the
        # batch size (view(-1, 512)) would silently turn a wrongly-sized input
        # into a different batch size; here fc1 raises a shape error instead.
        x = x.flatten(1)  # -> [B, 512]
        x = F.relu(self.fc1(x))  # -> [B, 32]
        x = self.fc2(x)  # -> [B, 10]
        return x

    def name(self):
        """Return the display name of this model."""
        return "SampleConvNet"
52
53
54class GradSampleHooksTest(unittest.TestCase):

Callers 3

setUp (method) · 0.70
test_remove_hooks (method) · 0.70
test_load_state_dict (method) · 0.70

Calls

no outgoing calls

Tested by 3

setUp (method) · 0.56
test_remove_hooks (method) · 0.56
test_load_state_dict (method) · 0.56