(
self,
in_channels,
scale_factor=2,
width=128,
height=32,
STN=False,
srb_nums=5,
mask=False,
hidden_units=32,
infer_mode=False,
**kwargs,
)
| 36 | |
| 37 | class TSRN(nn.Layer): |
| 38 | def __init__( |
| 39 | self, |
| 40 | in_channels, |
| 41 | scale_factor=2, |
| 42 | width=128, |
| 43 | height=32, |
| 44 | STN=False, |
| 45 | srb_nums=5, |
| 46 | mask=False, |
| 47 | hidden_units=32, |
| 48 | infer_mode=False, |
| 49 | **kwargs, |
| 50 | ): |
| 51 | super(TSRN, self).__init__() |
| 52 | in_planes = 3 |
| 53 | if mask: |
| 54 | in_planes = 4 |
| 55 | assert math.log(scale_factor, 2) % 1 == 0 |
| 56 | upsample_block_num = int(math.log(scale_factor, 2)) |
| 57 | self.block1 = nn.Sequential( |
| 58 | nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4), nn.PReLU() |
| 59 | ) |
| 60 | self.srb_nums = srb_nums |
| 61 | for i in range(srb_nums): |
| 62 | setattr(self, "block%d" % (i + 2), RecurrentResidualBlock(2 * hidden_units)) |
| 63 | |
| 64 | setattr( |
| 65 | self, |
| 66 | "block%d" % (srb_nums + 2), |
| 67 | nn.Sequential( |
| 68 | nn.Conv2D(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1), |
| 69 | nn.BatchNorm2D(2 * hidden_units), |
| 70 | ), |
| 71 | ) |
| 72 | |
| 73 | block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)] |
| 74 | block_.append(nn.Conv2D(2 * hidden_units, in_planes, kernel_size=9, padding=4)) |
| 75 | setattr(self, "block%d" % (srb_nums + 3), nn.Sequential(*block_)) |
| 76 | self.tps_inputsize = [height // scale_factor, width // scale_factor] |
| 77 | tps_outputsize = [height // scale_factor, width // scale_factor] |
| 78 | num_control_points = 20 |
| 79 | tps_margins = [0.05, 0.05] |
| 80 | self.stn = STN |
| 81 | if self.stn: |
| 82 | self.tps = TPSSpatialTransformer( |
| 83 | output_image_size=tuple(tps_outputsize), |
| 84 | num_control_points=num_control_points, |
| 85 | margins=tuple(tps_margins), |
| 86 | ) |
| 87 | |
| 88 | self.stn_head = STN_model( |
| 89 | in_channels=in_planes, |
| 90 | num_ctrlpoints=num_control_points, |
| 91 | activation="none", |
| 92 | ) |
| 93 | self.out_channels = in_channels |
| 94 | |
| 95 | self.r34_transformer = Transformer() |
no test coverage detected