Generic implementation of residual blocks. This implements a generic residual block from He et al., "Identity Mappings in Deep Residual Networks" (2016), https://arxiv.org/abs/1603.05027, which can be further customized via factory functions.
| 94 | |
| 95 | |
| 96 | class ResidualBlock(nn.Module): |
| 97 | """Generic implementation of residual blocks. |
| 98 | |
| 99 | This implements a generic residual block from |
| 100 | He et al. - Identity Mappings in Deep Residual Networks (2016), |
| 101 | https://arxiv.org/abs/1603.05027 |
| 102 | which can be further customized via factory functions. |
| 103 | """ |
| 104 | |
| 105 | def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None: |
| 106 | """Initialize ResidualBlock.""" |
| 107 | super().__init__() |
| 108 | self.residual = residual |
| 109 | self.shortcut = shortcut |
| 110 | |
| 111 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 112 | """Apply residual block.""" |
| 113 | delta_x = self.residual(x) |
| 114 | |
| 115 | if self.shortcut is not None: |
| 116 | x = self.shortcut(x) |
| 117 | |
| 118 | return x + delta_x |
| 119 | |
| 120 | |
| 121 | class FeatureFusionBlock2d(nn.Module): |