(self, x)
| 31 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) |
| 32 | |
def forward(self, x):
    """Apply batch-norm to the leading channels and instance-norm to the rest.

    The input is sliced along dim 1 at ``self.bnorm_channels``:
    channels ``[0, bnorm_channels)`` go through ``self.bnorm`` and channels
    ``[bnorm_channels, C)`` go through ``self.inorm``. The two results are
    concatenated back along dim 1 in the same order.

    Args:
        x: tensor whose dim 1 is the channel axis (presumably (N, C, H, W)
            for 2-d norm layers -- TODO confirm against callers).

    Returns:
        ``torch.cat`` of the two normalized slices along dim 1.
    """
    split_at = self.bnorm_channels
    # Slices are views into x; make each contiguous before handing it to the
    # norm layer, as the original implementation does.
    leading = x[:, :split_at, ...].contiguous()
    trailing = x[:, split_at:, ...].contiguous()
    normalized = [self.bnorm(leading), self.inorm(trailing)]
    return torch.cat(normalized, dim=1)
| 38 | |
| 39 | |
| 40 | class Conv2dIBNormRelu(nn.Module): |