(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs)
| 668 | """ |
| 669 | |
| 670 | def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs): |
| 671 | self.low_res_diffusion = low_res_diffusion |
| 672 | self.interpolate_mode = interpolate_mode |
| 673 | super().__init__(*args, **kwargs) |
| 674 | |
| 675 | self.aug_proj = nn.Sequential( |
| 676 | linear(self.model_channels, self.time_embed_dim, dtype=self.dtype), |
| 677 | get_activation(kwargs['activation']), |
| 678 | linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), |
| 679 | ) |
| 680 | |
| 681 | def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): |
| 682 | bs, _, new_height, new_width = x.shape |
nothing calls this directly
no test coverage detected