MCPcopy
hub / github.com/thygate/stable-diffusion-webui-depthmap-script / Inpaint_Edge_Net

Class Inpaint_Edge_Net

inpaint/networks.py:237–330  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

235 return output
236
237class Inpaint_Edge_Net(BaseNetwork):
def __init__(self, residual_blocks=8, init_weights=True):
    """Construct the layers of the edge inpainting network.

    Layout (as built below):
      * ``encoder_0``: reflection pad 3 + 7x7 conv, 7 -> 64 channels
        (spatial size preserved), InstanceNorm, ReLU.
      * ``encoder_1`` / ``encoder_2``: 4x4 stride-2 convs, 64 -> 128 -> 256
        channels; each halves H and W for even sizes.
      * ``middle``: ``residual_blocks`` copies of ``ResnetBlock(256, 2)``
        (NOTE(review): the second argument is presumably a dilation rate --
        ResnetBlock is defined elsewhere, confirm).
      * ``decoder_0`` / ``decoder_1``: 4x4 stride-2 transposed convs,
        512 -> 128 and 256 -> 64 channels; each doubles H and W.
      * ``decoder_2``: reflection pad 3 + plain 7x7 conv, 128 -> 1 channel,
        with no normalisation or activation.

    The decoder input widths (256+256, 128+128, 64+64) suggest that each
    decoder stage receives its predecessor's output concatenated with a
    same-width encoder feature map. The forward pass is not visible here,
    so confirm this there.

    Args:
        residual_blocks: number of ResnetBlock modules stacked in ``middle``.
        init_weights: when truthy, ``self.init_weights()`` is called after all
            layers exist (presumably inherited from BaseNetwork -- confirm).
    """
    super(Inpaint_Edge_Net, self).__init__()
    in_ch, out_ch = 7, 1

    def _norm_relu(width):
        # Parameter-free InstanceNorm followed by an in-place ReLU.
        return (nn.InstanceNorm2d(width, track_running_stats=False),
                nn.ReLU(inplace=True))

    # NOTE(review): never populated or read in the visible code; kept in case
    # something outside this view relies on the attribute existing.
    self.encoder = []

    # Layer construction order is deliberately unchanged: conv weight init
    # and (presumably) spectral_norm may draw from the RNG, so reordering
    # could change the initial weights.
    # NOTE(review): `spectral_norm(module, True)` is a project helper; the
    # boolean presumably toggles spectral normalisation -- confirm.
    self.encoder_0 = nn.Sequential(
        nn.ReflectionPad2d(3),
        spectral_norm(nn.Conv2d(in_ch, 64, kernel_size=7, padding=0), True),
        *_norm_relu(64),
    )
    self.encoder_1 = nn.Sequential(
        spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), True),
        *_norm_relu(128),
    )
    self.encoder_2 = nn.Sequential(
        spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), True),
        *_norm_relu(256),
    )

    self.middle = nn.Sequential(*[ResnetBlock(256, 2) for _ in range(residual_blocks)])

    self.decoder_0 = nn.Sequential(
        spectral_norm(nn.ConvTranspose2d(256 + 256, 128, kernel_size=4, stride=2, padding=1), True),
        *_norm_relu(128),
    )
    self.decoder_1 = nn.Sequential(
        spectral_norm(nn.ConvTranspose2d(128 + 128, 64, kernel_size=4, stride=2, padding=1), True),
        *_norm_relu(64),
    )
    # Final projection to the output channel(s): no spectral norm, no
    # normalisation, no activation.
    self.decoder_2 = nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(64 + 64, out_ch, kernel_size=7, padding=0),
    )

    if init_weights:
        self.init_weights()
284
285 def add_border(self, input, channel_pad_1=None):
286 h = input.shape[-2]
287 w = input.shape[-1]
288 require_len_unit = 16
289 residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h) # + 2*require_len_unit
290 residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w) # + 2*require_len_unit
291 enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
292 if channel_pad_1 is not None:
293 for channel in channel_pad_1:
294 enlarge_input[:, channel] = 1

Callers (2)

run_3dphoto (Function) · 0.90
main.py (File) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected