| 132 | return h, h_mask |
| 133 | |
| 134 | class Inpaint_Depth_Net(nn.Module): |
def __init__(self, layer_size=7, upsampling_mode='nearest'):
    """Build the encoder/decoder stack of ``PCBActiv`` stages.

    Args:
        layer_size: index of the deepest encoder/decoder stage. ``enc_1`` to
            ``enc_4`` and ``dec_1`` to ``dec_4`` always exist. Stages 5 up to
            ``layer_size`` are added as 512-wide encoder/decoder pairs. The value
            is stored as ``self.layer_size``; ``add_border`` uses it to compute
            its padding unit ``2 ** self.layer_size``.
        upsampling_mode: stored as ``self.upsampling_mode``. Presumably this is an
            interpolation mode consumed by ``forward``, which is not shown
            here (TODO confirm).
    """
    super().__init__()
    # 4 input channels match the (depth, edge, context, mask) concatenation
    # built in forward_3P; the network emits a single channel.
    n_in, n_out = 4, 1
    self.freeze_enc_bn = False
    self.upsampling_mode = upsampling_mode
    self.layer_size = layer_size

    # Construction order is intentionally unchanged: enc_1..enc_L, then
    # dec_L..dec_5, then dec_4..dec_1. It fixes submodule registration order
    # (and therefore parameter order). It presumably also fixes the order in
    # which PCBActiv draws random initial weights.
    # The 'down-K' sample tags are passed straight through to PCBActiv.
    self.enc_1 = PCBActiv(n_in, 64, bn=False, sample='down-7', conv_bias=True)
    self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True)
    self.enc_3 = PCBActiv(128, 256, sample='down-5')
    self.enc_4 = PCBActiv(256, 512, sample='down-3')

    deep_stages = range(5, self.layer_size + 1)  # empty when layer_size <= 4
    for stage in deep_stages:
        setattr(self, f'enc_{stage}', PCBActiv(512, 512, sample='down-3'))

    # Decoder input widths are sums of two feature widths. Presumably these are
    # upsampled deeper features plus the matching encoder output (confirm
    # against forward).
    for stage in deep_stages:
        setattr(self, f'dec_{stage}', PCBActiv(512 + 512, 512, activ='leaky'))
    self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
    self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
    self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
    # Final stage: projects to n_out channels with bn and activation disabled.
    self.dec_1 = PCBActiv(64 + n_in, n_out,
                          bn=False, activ=None, conv_bias=True)
def add_border(self, input, mask_flag, PCONV=True):
    """Pad ``input`` spatially so H and W become multiples of ``2 ** self.layer_size``.

    The original tensor is placed at the centre of a larger canvas. When the
    total padding on an axis is odd, the extra pixel goes to the bottom/right.

    Args:
        input: tensor indexed as ``(N, C, H, W)``. The canvas copies
            ``input.shape[0]`` and ``input.shape[1]`` and pads the last two
            dimensions.
        mask_flag: if truthy, the padding region is initialised according to
            ``PCONV`` instead of being left at zero.
        PCONV: only consulted when ``mask_flag`` is truthy.
            - Exactly ``False``: the padding is filled with 1.0.
            - Any other value: channel 2 is set to 0.0. This is a no-op on the
              zero canvas and raises ``IndexError`` if C < 3.

    Returns:
        ``(enlarged, [top, bottom, left, right])``.
        - ``enlarged`` is a default-dtype tensor on ``input.device``.
        - ``enlarged[..., top:bottom, left:right]`` holds a copy of ``input``,
          cast to the default dtype.
    """
    unit = 2 ** self.layer_size
    with torch.no_grad():
        n, c = input.shape[0], input.shape[1]
        h, w = input.shape[-2], input.shape[-1]
        # Smallest non-negative padding that rounds h/w up to a multiple of
        # `unit`. This is the exact integer form of ceil(h / unit) * unit - h.
        residual_h = (-h) % unit
        residual_w = (-w) % unit
        # Allocate directly on the input's device. Building the canvas on the
        # CPU and copying it over costs a full-size host allocation plus a
        # transfer. The dtype stays the default dtype, as before.
        enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w),
                                    device=input.device)
        if mask_flag:
            if PCONV is False:
                # Padding region becomes 1.0. The centre is overwritten below.
                enlarge_input.fill_(1.0)
            else:
                # NOTE(review): the canvas is all zeros here, so this assignment
                # changes nothing except raising IndexError when C < 3. It is
                # kept for compatibility; confirm the intended behaviour.
                enlarge_input[:, 2, ...] = 0.0
        anchor_h = residual_h // 2
        anchor_w = residual_w // 2
        enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input
    return enlarge_input, [anchor_h, anchor_h + h, anchor_w, anchor_w + w]
| 177 | |
| 178 | def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None): |
| 179 | with torch.no_grad(): |
| 180 | input = torch.cat((depth, edge, context, mask), dim=1) |
| 181 | n, c, h, w = input.shape |
| 182 | residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h) |
| 183 | residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w) |
| 184 | anchor_h = residual_h//2 |
| 185 | anchor_w = residual_w//2 |
| 186 | enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda) |
| 187 | enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input |
| 188 | # enlarge_input[:, 3] = 1. - enlarge_input[:, 3] |
| 189 | depth_output = self.forward(enlarge_input) |
| 190 | depth_output = depth_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] |
| 191 | # import pdb; pdb.set_trace() |
no outgoing calls
no test coverage detected