MCPcopy
hub / github.com/TingsongYu/PyTorch_Tutorial / Net

Class Net

Code/utils/utils.py:13–45  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

11
12
class Net(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

    The first linear layer expects ``16 * 5 * 5`` features per sample, which is
    what the two ``5x5`` convolutions and ``2x2`` poolings produce for
    3-channel ``32x32`` images.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``; no softmax is applied.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield
                ``16 * 5 * 5`` features per sample.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample. Pinning the batch dimension (instead of letting
        # view() infer it with -1) makes a wrongly sized input fail at fc1
        # rather than being silently reshaped into a different batch size.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # Weight initialization
    def initialize_weights(self):
        """Re-initialize the parameters of every submodule in place.

        * ``nn.Conv2d``: Xavier-normal weights, zero bias.
        * ``nn.BatchNorm2d``: weight filled with 1, bias zeroed.
        * ``nn.Linear``: weights drawn from N(0, 0.01), zero bias.

        Layers created without a bias (or without affine parameters) are
        handled by skipping the missing tensors.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # weight/bias are None when the layer uses affine=False.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
46
47class MyDataset(Dataset):
48 def __init__(self, txt_path, transform = None, target_transform = None):

Callers 2

Calls

no outgoing calls

Tested by

no test coverage detected