Repository: github.com/openai/guided-diffusion

Class SuperResModel

guided_diffusion/unet.py, lines 666–680

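A UNetModel subclass that performs super-resolution. It expects an extra keyword argument, `low_res`, holding the low-resolution image to condition on. It conditions by bilinearly upsampling `low_res` to the spatial size of the input `x` and concatenating the result with `x` along the channel dimension.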
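The following minimal PyTorch sketch traces the shape bookkeeping end to end without building the full network. `StandInUNet` is a hypothetical placeholder for `UNetModel`: it is a single 1×1 convolution that ignores `timesteps`, while the real `UNetModel` takes many more constructor arguments. `StandInSuperRes` copies the two `SuperResModel` methods from guided_diffusion/unet.py onto that placeholder.

```python
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class StandInUNet(nn.Module):
    """Hypothetical stand-in for UNetModel: one 1x1 conv; timesteps are ignored."""

    def __init__(self, image_size, in_channels, out_channels=3):
        super().__init__()
        self.image_size = image_size
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, timesteps, **kwargs):
        return self.conv(x)


class StandInSuperRes(StandInUNet):
    """Same two methods as SuperResModel, applied to the stand-in base class."""

    def __init__(self, image_size, in_channels, *args, **kwargs):
        # The base network sees [x, upsampled low_res], so it needs 2x channels.
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)


model = StandInSuperRes(image_size=64, in_channels=3)
x = th.randn(2, 3, 64, 64)          # high-resolution input batch
low_res = th.randn(2, 3, 16, 16)    # conditioning images at 1/4 resolution
t = th.tensor([10, 500])            # one timestep per batch element
out = model(x, t, low_res=low_res)
print(model.conv.in_channels, tuple(out.shape))  # 6 (2, 3, 64, 64)
```

With 3-channel images, the base network is built with 6 input channels, and the output keeps the spatial size of `x`. Note that `low_res` defaults to `None` but is passed straight to `F.interpolate`, so calling the model without it raises an error. In practice the argument is required.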


```python
# guided_diffusion/unet.py, lines 666–680.
# Names used here come from the top of unet.py:
#   import torch as th
#   import torch.nn.functional as F
# UNetModel is defined earlier in the same file.

class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        # Double the input channels: forward() concatenates the input x
        # with the upsampled low_res image along the channel axis.
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        # Resize the conditioning image to x's spatial size
        # (bilinear, align_corners left at its default).
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
```

Callers (1)

sr_create_model · Function · 0.85

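Calls

No outgoing calls detected by the indexer. Within the listing itself, the methods call `F.interpolate`, `th.cat`, and the `UNetModel` constructor and `forward` through `super()`.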

Tested by

no test coverage detected