MCPcopy
hub / github.com/XingangPan/DragGAN / __init__

Method __init__

stylegan_human/insetgan.py:20–57  ·  view source on GitHub ↗
(self, stylebody_ckpt, styleface_ckpt)

Source from the content-addressed store, hash-verified

18
19class InsetGAN(torch.nn.Module):
def __init__(self, stylebody_ckpt, styleface_ckpt):
    """Set up the body/face generators, the dlib face tools and the losses.

    Args:
        stylebody_ckpt: Path to the body generator checkpoint. A path
            ending in ``.pkl`` is converted once, via ``legacy.convert``,
            into a sibling ``.pth`` file, and that file is what gets loaded.
        styleface_ckpt: Path to the face generator checkpoint; handled the
            same way as ``stylebody_ckpt``.

    Note:
        Both generators and the LPIPS network are moved to CUDA, and the
        dlib model files are resolved relative to the current working
        directory (``pretrained_models/``).
    """
    super().__init__()

    def ensure_pth(ckpt):
        # Swap only a trailing ".pkl": a blanket str.replace would also
        # rewrite ".pkl" occurring in a directory name and point at a
        # location that does not exist.
        if ckpt.endswith('.pkl'):
            pth = ckpt[:-len('.pkl')] + '.pth'
        else:
            pth = ckpt
        # Convert lazily so an existing .pth is reused on later runs.
        if not os.path.exists(pth):
            legacy.convert(ckpt, pth)
        return pth

    ## convert pkl to pth
    stylebody_ckpt = ensure_pth(stylebody_ckpt)
    styleface_ckpt = ensure_pth(styleface_ckpt)

    # dual generator
    config = {"latent": 512, "n_mlp": 8, "channel_multiplier": 2}
    generator_kwargs = dict(
        size=1024,
        style_dim=config["latent"],
        n_mlp=config["n_mlp"],
        channel_multiplier=config["channel_multiplier"],
    )

    # Checkpoints are read onto the CPU: the freshly built modules live on
    # the CPU at load_state_dict time, so this avoids depending on the
    # device the checkpoint was saved from. The move to CUDA happens after.
    self.body_generator = bodyGAN(**generator_kwargs)
    body_state = torch.load(stylebody_ckpt, map_location='cpu')
    self.body_generator.load_state_dict(body_state['g_ema'])
    self.body_generator.eval().requires_grad_(False).cuda()

    self.face_generator = FaceGAN(**generator_kwargs)
    face_state = torch.load(styleface_ckpt, map_location='cpu')
    self.face_generator.load_state_dict(face_state['g_ema'])
    self.face_generator.eval().requires_grad_(False).cuda()

    # crop function
    self.dlib_predictor = dlib.shape_predictor(
        './pretrained_models/shape_predictor_68_face_landmarks.dat'
    )
    self.dlib_cnn_face_detector = dlib.cnn_face_detection_model_v1(
        "pretrained_models/mmod_human_face_detector.dat"
    )

    # criterion
    self.lpips_loss = LPIPS(net='alex').cuda().eval()
    self.l1_loss = torch.nn.L1Loss(reduction='mean')
58
59 def loss_coarse(self, A_face, B, p1=500, p2=0.05):
60 A_face = F.interpolate(A_face, size=(64, 64), mode='area')

Callers

nothing calls this directly

Calls 2

loadMethod · 0.80
convertMethod · 0.45

Tested by

no test coverage detected