Create a residual block.
(num_features: int, batch_norm: bool)
| 181 | |
@staticmethod
def _residual_block(num_features: int, batch_norm: bool):
    """Create a residual block.

    The residual branch consists of two ``ReLU -> Conv2d(3x3)`` stages.
    When ``batch_norm`` is set, each stage also ends with a
    ``BatchNorm2d``. Each stage is built separately, so the two
    convolutions have independent weights. The branch is wrapped in
    ``ResidualBlock``.

    Args:
        num_features: Number of input and output channels for every
            convolution, and the feature count of the batch-norm layers.
        batch_norm: If True, append ``BatchNorm2d`` after each convolution
            and build the convolutions without a bias term.

    Returns:
        A ``ResidualBlock`` wrapping the ``nn.Sequential`` branch.
    """

    def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
        """Build one ``ReLU -> Conv2d [-> BatchNorm2d]`` stage with ``dim`` channels."""
        layers = [
            # Out-of-place: an in-place ReLU as the first layer would
            # overwrite the block's input tensor. Presumably
            # ResidualBlock still needs that input for the skip
            # connection -- confirm against ResidualBlock.
            nn.ReLU(inplace=False),
            # 3x3 kernel, stride 1, padding 1 keeps the spatial size.
            # Use ``dim`` (not the closed-over num_features) so the conv
            # and the BatchNorm below always agree on the channel count.
            nn.Conv2d(
                dim,
                dim,
                kernel_size=3,
                stride=1,
                padding=1,
                # A following BatchNorm re-centres the activations, which
                # makes a conv bias redundant.
                bias=not batch_norm,
            ),
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(dim))
        return layers

    residual = nn.Sequential(
        *_create_block(dim=num_features, batch_norm=batch_norm),
        *_create_block(dim=num_features, batch_norm=batch_norm),
    )
    return ResidualBlock(residual)