MCPcopy
hub / github.com/TingsongYu/PyTorch_Tutorial / Net

Class Net

Code/2_model/2_finetune.py:66–98  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

64
65
class Net(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

    The layer sizes fix the expected input to ``(N, 3, 32, 32)``:
    32 -> conv(5) -> 28 -> pool(2) -> 14 -> conv(5) -> 10 -> pool(2) -> 5,
    so each sample flattens to ``16 * 5 * 5`` features, matching ``fc1``.
    ``forward`` returns the raw output of ``fc3`` with shape ``(N, 10)``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch ``x`` of shape ``(N, 3, 32, 32)``.

        Returns:
            Tensor of shape ``(N, 10)`` from ``fc3`` (no softmax applied).

        Raises:
            RuntimeError: if the spatial size is not 32x32, because the
                per-sample feature count then no longer matches ``fc1``.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension explicit. The old
        # ``view(-1, 16 * 5 * 5)`` let mis-sized inputs silently re-chunk
        # features across samples (one 52x52 image became 4 "samples").
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # Weight initialisation
    def initialize_weights(self):
        """Re-initialise the parameters of every supported submodule in place.

        * ``nn.Conv2d``: Xavier-normal weight, zero bias.
        * ``nn.BatchNorm2d``: weight filled with 1, bias with 0.
        * ``nn.Linear``: weight drawn from a normal with mean 0, std 0.01;
          zero bias.

        Parameters that do not exist (``bias=False`` layers,
        ``affine=False`` batch norms) are skipped rather than raising
        ``AttributeError``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


net = Net()  # create the network instance

Callers 1

2_finetune.py · File · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected