| 4 | |
| 5 | |
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + norm + ReLU stages plus a skip connection.

    The first convolution applies ``stride``; the second keeps the spatial
    size. When the input cannot be added to the branch output as-is (the
    stride is not 1 or the channel count changes), the skip path is
    projected with a 1x1 convolution followed by its own norm layer.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        norm_fn: One of ``'group'``, ``'batch'``, ``'instance'`` or ``'none'``.
        stride: Stride of the first 3x3 conv (and of the 1x1 projection).

    Raises:
        ValueError: If ``norm_fn`` is not one of the supported values.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        # Fail fast: an unknown norm_fn would otherwise leave norm1/norm2
        # undefined and surface later as a confusing AttributeError.
        if norm_fn not in ('group', 'batch', 'instance', 'none'):
            raise ValueError(
                "norm_fn must be one of 'group', 'batch', 'instance', 'none'; "
                "got %r" % (norm_fn,))

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        # Aim for 8 channels per group. GroupNorm needs at least one group,
        # and planes // 8 is 0 when planes < 8.
        num_groups = max(1, planes // 8)

        self.norm1 = self._make_norm(norm_fn, planes, num_groups)
        self.norm2 = self._make_norm(norm_fn, planes, num_groups)

        # The identity shortcut is only valid when x already has the
        # branch output's shape. Otherwise project it with a strided 1x1 conv.
        # Checking the stride alone misses a channel-only change.
        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.norm3 = self._make_norm(norm_fn, planes, num_groups)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    @staticmethod
    def _make_norm(norm_fn, planes, num_groups):
        """Return a new normalization layer of kind ``norm_fn`` for ``planes`` channels."""
        if norm_fn == 'group':
            return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        if norm_fn == 'batch':
            return nn.BatchNorm2d(planes)
        if norm_fn == 'instance':
            return nn.InstanceNorm2d(planes)
        # 'none': an empty Sequential passes its input through unchanged.
        return nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x``.

        Assumes ``x`` is a (N, in_planes, H, W) tensor. The result has
        ``planes`` channels, with the spatial size reduced by ``stride``.
        """
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
| 57 | |
| 58 | |
| 59 | |