A UNetModel that performs super-resolution. Expects an extra kwarg `low_res` to condition on a low-resolution image.
| 664 | |
| 665 | |
class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    In `forward`, `low_res` is bilinearly resized to the spatial size of the
    input `x` and concatenated with `x` along dim 1. The result is then fed
    to the parent UNetModel, so the parent is built with `in_channels * 2`
    input channels.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        """
        :param image_size: forwarded unchanged to UNetModel.
        :param in_channels: channel count of `x`. The parent UNetModel is
            created with twice this many input channels, because `x` and the
            resized `low_res` are stacked on the channel axis. NOTE(review):
            this presumably means `low_res` also has `in_channels` channels;
            confirm against callers.
        :param args: forwarded unchanged to UNetModel.
        :param kwargs: forwarded unchanged to UNetModel.
        """
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """
        Run the parent UNetModel on `x`, conditioned on `low_res`.

        :param x: a rank-4 batch whose last two dims give the target
            (height, width). Dim 1 is the axis along which `low_res` is
            concatenated.
        :param timesteps: forwarded unchanged to UNetModel.forward.
        :param low_res: the low-resolution conditioning batch. It is required
            despite the `None` default, which is kept only for signature
            compatibility. It must be a 4-D tensor (bilinear interpolation
            requires one) with the same batch size as `x`.
        :param kwargs: forwarded unchanged to UNetModel.forward.
        :return: the parent UNetModel.forward output for the concatenated
            input.
        :raises ValueError: if `low_res` is not provided.
        """
        if low_res is None:
            # Fail here with a clear message. Otherwise F.interpolate fails
            # on None with an error that does not name the real cause.
            raise ValueError("SuperResModel.forward requires a `low_res` tensor")
        _, _, new_height, new_width = x.shape
        # align_corners=False is the bilinear default; stating it explicitly
        # documents the resampling convention and avoids the default-change
        # warning on PyTorch versions that emit one when it is omitted.
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
| 681 | |
| 682 | |
| 683 | class EncoderUNetModel(nn.Module): |