(x, factor=2, gain=1)
| 173 | |
| 174 | |
def upscale2d(x, factor=2, gain=1):
    """Scale ``x`` by ``gain`` and replicate each element of its last two dims.

    Every element along dimensions 2 and 3 is repeated ``factor`` times in
    each of those dimensions (nearest-neighbour style). The output therefore
    has shape ``(shape[0], shape[1], factor * shape[2], factor * shape[3])``.
    The input is presumably laid out as (N, C, H, W); TODO confirm with callers.

    Args:
        x: 4-D tensor to upscale.
        factor: Replication count applied to dims 2 and 3. When it equals 1,
            no replication is performed.
        gain: Multiplier applied to ``x`` before replication. When it equals 1,
            no multiplication is performed.

    Returns:
        The scaled and replicated tensor. If both ``gain`` and ``factor`` are
        1, ``x`` itself is returned unchanged.

    Raises:
        AssertionError: if ``x`` is not 4-dimensional.
    """
    assert x.dim() == 4
    if gain != 1:
        x = x * gain
    if factor == 1:
        return x
    n, c, h, w = x.shape
    # Insert singleton axes after dims 2 and 3, then broadcast them to
    # `factor` without copying; the final contiguous copy materializes it.
    tiled = x[:, :, :, None, :, None].expand(-1, -1, -1, factor, -1, factor)
    return tiled.contiguous().view(n, c, h * factor, w * factor)
| 186 | |
| 187 | |
| 188 | class Upscale2d(nn.Module): |