Interpolation module.
| 209 | |
| 210 | |
class Interpolate(nn.Module):
    """Interpolation module.

    Thin ``nn.Module`` wrapper around ``nn.functional.interpolate`` with a
    fixed scale factor, mode and corner-alignment setting, so resampling can
    be placed inside ``nn.Sequential`` and similar containers.
    """

    # Modes for which nn.functional.interpolate accepts a non-None
    # ``align_corners``; every other mode (nearest, nearest-exact, area)
    # raises ValueError unless it is None.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float or tuple of float): multiplier for the
                spatial size of the input
            mode (str): interpolation mode passed to
                ``nn.functional.interpolate`` (e.g. "nearest", "bilinear")
            align_corners (bool): corner alignment for the interpolating
                modes (linear, bilinear, bicubic, trilinear). When False it
                is ignored for the other modes; True with those modes is
                still forwarded, so torch reports the misuse.
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        align_corners = self.align_corners
        if self.mode not in self._ALIGN_CORNERS_MODES and not align_corners:
            # The False default has no meaning for non-interpolating modes,
            # and torch rejects any non-None value there, so drop it.
            align_corners = None

        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=align_corners
        )

        return x
| 244 | |
| 245 | |
| 246 | class ResidualConvUnit(nn.Module): |