MCPcopy
hub / github.com/TingsongYu/PyTorch_Tutorial / Net

Class Net

Code/main_training/main.py:69–101  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

67
68
class Net(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    The first linear layer expects 16 feature maps of size 5x5, which is what
    the conv/pool stack produces for 3-channel 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        num_classes: Number of output scores produced by the last layer.
            Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Run the network on a batch and return the unnormalized class scores.

        Args:
            x: Input batch of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, num_classes).

        Raises:
            RuntimeError: If the conv features cannot be reshaped to
                (batch, 16 * 5 * 5), i.e. the input is not 32x32.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Keep the batch dimension explicit: with view(-1, ...) a wrongly sized
        # input could be silently folded into extra "samples" instead of failing.
        x = x.view(x.size(0), 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # Define weight initialization
    def initialize_weights(self):
        """Re-initialize the parameters of every submodule in place.

        Conv2d weights get Xavier-normal init, BatchNorm2d weights are set to 1,
        Linear weights are drawn from N(0, 0.01); all biases are set to 0.
        Parameters that are absent (``bias=False`` / ``affine=False``) are
        skipped instead of raising.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # weight/bias are None when the layer was built with affine=False.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
102
103
net = Net() # Create a network instance

Callers 1

main.py · File · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected