(self, x)
| 26 | self.stride = stride |
| 27 | |
def forward(self, x):
    """Apply the residual sparse basic block to ``x``.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + identity)``, where
    ``identity`` is ``x.features`` or, when ``self.downsample`` is set, the
    result of ``self.downsample(x)``. Every feature update goes through the
    ``replace_feature`` helper.

    Args:
        x: Sparse tensor object exposing a ``.features`` attribute, which
            must be 2-D; it is fed to ``self.conv1``. Presumably an spconv
            ``SparseConvTensor`` with features laid out as
            (num_active_sites, channels) -- TODO confirm.

    Returns:
        The ``self.conv2`` output with its features replaced (via
        ``replace_feature``) by the ReLU-activated residual sum.

    Raises:
        ValueError: If ``x.features`` is not 2-D.
    """
    identity = x.features

    # Explicit raise instead of ``assert``: ``python -O`` strips assert
    # statements, which would silently disable this input check.
    if identity.dim() != 2:
        raise ValueError('x.features.dim()=%d' % identity.dim())

    out = self.conv1(x)
    out = replace_feature(out, self.bn1(out.features))
    out = replace_feature(out, self.relu(out.features))

    out = self.conv2(out)
    out = replace_feature(out, self.bn2(out.features))

    if self.downsample is not None:
        shortcut = self.downsample(x)
        # ``self.downsample`` receives the same sparse input as ``self.conv1``.
        # If it returns a sparse tensor wrapper (e.g. a sparse conv module)
        # rather than a raw feature tensor, unwrap ``.features`` so the
        # residual sum below adds tensors; a raw tensor is used unchanged.
        identity = getattr(shortcut, 'features', shortcut)

    # Residual connection. Assumes ``out.features`` and ``identity`` have
    # matching shapes (same active sites, same channel count) -- TODO confirm.
    out = replace_feature(out, out.features + identity)
    out = replace_feature(out, self.relu(out.features))

    return out
| 47 | |
| 48 | |
| 49 | class UNetV2(nn.Module): |
nothing calls this directly
no test coverage detected