MCPcopy
hub / github.com/thygate/stable-diffusion-webui-depthmap-script / Inpaint_Depth_Net

Class Inpaint_Depth_Net

inpaint/networks.py:134–235  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

132 return h, h_mask
133
134class Inpaint_Depth_Net(nn.Module):
135 def __init__(self, layer_size=7, upsampling_mode='nearest'):
136 super().__init__()
137 in_channels = 4
138 out_channels = 1
139 self.freeze_enc_bn = False
140 self.upsampling_mode = upsampling_mode
141 self.layer_size = layer_size
142 self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7', conv_bias=True)
143 self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True)
144 self.enc_3 = PCBActiv(128, 256, sample='down-5')
145 self.enc_4 = PCBActiv(256, 512, sample='down-3')
146 for i in range(4, self.layer_size):
147 name = 'enc_{:d}'.format(i + 1)
148 setattr(self, name, PCBActiv(512, 512, sample='down-3'))
149
150 for i in range(4, self.layer_size):
151 name = 'dec_{:d}'.format(i + 1)
152 setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky'))
153 self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
154 self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
155 self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
156 self.dec_1 = PCBActiv(64 + in_channels, out_channels,
157 bn=False, activ=None, conv_bias=True)
158 def add_border(self, input, mask_flag, PCONV=True):
159 with torch.no_grad():
160 h = input.shape[-2]
161 w = input.shape[-1]
162 require_len_unit = 2 ** self.layer_size
163 residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit
164 residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit
165 enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
166 if mask_flag:
167 if PCONV is False:
168 enlarge_input += 1.0
169 enlarge_input = enlarge_input.clamp(0.0, 1.0)
170 else:
171 enlarge_input[:, 2, ...] = 0.0
172 anchor_h = residual_h//2
173 anchor_w = residual_w//2
174 enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
175
176 return enlarge_input, [anchor_h, anchor_h+h, anchor_w, anchor_w+w]
177
178 def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None):
179 with torch.no_grad():
180 input = torch.cat((depth, edge, context, mask), dim=1)
181 n, c, h, w = input.shape
182 residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
183 residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
184 anchor_h = residual_h//2
185 anchor_w = residual_w//2
186 enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w)).to(cuda)
187 enlarge_input[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w] = input
188 # enlarge_input[:, 3] = 1. - enlarge_input[:, 3]
189 depth_output = self.forward(enlarge_input)
190 depth_output = depth_output[..., anchor_h:anchor_h+h, anchor_w:anchor_w+w]
191 # import pdb; pdb.set_trace()

Callers 2

run_3dphoto — Function · 0.90
main.py — File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected