| 74 | |
| 75 | class ResNet(nn.Module): |
def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
    """Assemble the stem, the four residual stages and the pooling head.

    Args:
        block: Residual block class; forwarded unchanged to
            ``self._make_layer`` for every stage.
        num_blocks: Indexable whose entries 0..3 are forwarded to
            ``self._make_layer`` as the block count of ``layer1``..``layer4``.
        in_channel: Number of input channels of the stem convolution.
        zero_init_residual: When true, the weight of ``bn3`` in every
            ``Bottleneck`` and of ``bn2`` in every ``BasicBlock`` is set to
            zero after the regular weight initialisation.
    """
    super(ResNet, self).__init__()

    stem_width = 64
    # NOTE(review): presumably read and advanced by _make_layer while the
    # stages are built -- confirm against its body.
    self.in_planes = stem_width

    # Stem: 3x3 convolution at stride 1 followed by batch norm.
    self.conv1 = nn.Conv2d(
        in_channel,
        stem_width,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    self.bn1 = nn.BatchNorm2d(stem_width)

    # (planes, stride) for layer1..layer4, built in order. The attribute
    # names are kept so the registered submodule names stay the same.
    stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
    for index, (planes, stride) in enumerate(stage_specs):
        stage = self._make_layer(block, planes, num_blocks[index], stride=stride)
        setattr(self, 'layer%d' % (index + 1), stage)

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    # Default initialisation: Kaiming-normal for convolutions, unit weight
    # and zero bias for the normalisation layers.
    for module in self.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            continue
        if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)

    # Optionally zero the last BN of each residual branch so the branch
    # starts at zero and every residual block initially acts as an
    # identity. Reported to improve the model by 0.2~0.3% in
    # https://arxiv.org/abs/1706.02677
    if zero_init_residual:
        for module in self.modules():
            if isinstance(module, Bottleneck):
                nn.init.constant_(module.bn3.weight, 0)
            elif isinstance(module, BasicBlock):
                nn.init.constant_(module.bn2.weight, 0)
| 106 | |
| 107 | def _make_layer(self, block, planes, num_blocks, stride): |
| 108 | strides = [stride] + [1] * (num_blocks - 1) |