MCPcopy
hub / github.com/princeton-vl/RAFT / ResidualBlock

Class ResidualBlock

core/extractor.py:6–56  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

4
5
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual (skip) connection.

    The main path is ``conv1 -> norm1 -> ReLU -> conv2 -> norm2 -> ReLU``.
    ``stride`` is applied by ``conv1`` only.  When ``stride != 1`` the skip
    path is projected by a strided 1x1 convolution followed by ``norm3`` so
    that it matches the main path before the two are summed.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        norm_fn: Normalisation to use: ``'group'``, ``'batch'``,
            ``'instance'`` or ``'none'``.
        stride: Stride of the first convolution (and of the skip projection).

    Raises:
        ValueError: If ``norm_fn`` is not one of the supported names.

    Note:
        With ``stride == 1`` no projection is created, so the skip path is
        added unchanged; ``in_planes`` must then equal ``planes`` for the sum
        in ``forward`` to be valid.
    """

    _NORM_CHOICES = ('group', 'batch', 'instance', 'none')

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        # Fail fast: an unknown name used to leave norm1/norm2 undefined and
        # only surfaced later as an opaque AttributeError.
        if norm_fn not in self._NORM_CHOICES:
            raise ValueError(
                "norm_fn must be one of {}, got {!r}".format(
                    self._NORM_CHOICES, norm_fn))

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = self._make_norm(norm_fn, planes)
        self.norm2 = self._make_norm(norm_fn, planes)

        if stride == 1:
            self.downsample = None
        else:
            # norm3 is kept as its own attribute *and* placed inside the
            # Sequential so that both the 'norm3.*' and 'downsample.1.*'
            # state_dict keys continue to exist.
            self.norm3 = self._make_norm(norm_fn, planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    @staticmethod
    def _make_norm(norm_fn, planes):
        """Build one normalisation layer of kind ``norm_fn`` over ``planes`` channels."""
        if norm_fn == 'group':
            # One group per 8 channels, but never zero groups (planes < 8).
            num_groups = max(planes // 8, 1)
            return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        if norm_fn == 'batch':
            return nn.BatchNorm2d(planes)
        if norm_fn == 'instance':
            return nn.InstanceNorm2d(planes)
        # 'none': an empty Sequential passes its input through unchanged.
        return nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x`` and return ``ReLU(skip(x) + main(x))``."""
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
57
58
59

Callers 1

_make_layerMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected