| 18 | |
| 19 | class InsetGAN(torch.nn.Module): |
def __init__(self, stylebody_ckpt, styleface_ckpt):
    """Set up the two frozen generators, the dlib face models and the losses.

    Args:
        stylebody_ckpt: Path to the body generator checkpoint. If the path
            ends in ``.pkl`` the sibling ``.pth`` file is used instead, and
            is created through ``legacy.convert`` when it does not exist yet.
        styleface_ckpt: Path to the face generator checkpoint; resolved the
            same way as ``stylebody_ckpt``.
    """
    super().__init__()

    def _resolve_pth(ckpt):
        # Swap only a trailing '.pkl' extension. A plain str.replace would
        # also rewrite '.pkl' inside directory names and point at a path
        # that does not exist.
        suffix = '.pkl'
        pth = ckpt[:-len(suffix)] + '.pth' if ckpt.endswith(suffix) else ckpt
        # convert pkl to pth (only once; an existing .pth file is reused)
        if not os.path.exists(pth):
            legacy.convert(ckpt, pth)
        return pth

    stylebody_ckpt = _resolve_pth(stylebody_ckpt)
    styleface_ckpt = _resolve_pth(styleface_ckpt)

    # dual generator: both networks are built with the same hyper-parameters
    config = {"latent": 512, "n_mlp": 8, "channel_multiplier": 2}
    generator_kwargs = dict(
        size=1024,
        style_dim=config["latent"],
        n_mlp=config["n_mlp"],
        channel_multiplier=config["channel_multiplier"],
    )

    self.body_generator = bodyGAN(**generator_kwargs)
    # map_location='cpu' keeps the checkpoint off the GPU it was saved from;
    # the module is moved to CUDA explicitly right after the weights are set.
    body_state = torch.load(stylebody_ckpt, map_location='cpu')['g_ema']
    self.body_generator.load_state_dict(body_state)
    self.body_generator.eval().requires_grad_(False).cuda()

    self.face_generator = FaceGAN(**generator_kwargs)
    face_state = torch.load(styleface_ckpt, map_location='cpu')['g_ema']
    self.face_generator.load_state_dict(face_state)
    self.face_generator.eval().requires_grad_(False).cuda()

    # crop function: dlib landmark predictor and CNN face detector, loaded
    # from paths relative to the current working directory
    self.dlib_predictor = dlib.shape_predictor('./pretrained_models/shape_predictor_68_face_landmarks.dat')
    self.dlib_cnn_face_detector = dlib.cnn_face_detection_model_v1("pretrained_models/mmod_human_face_detector.dat")

    # criterion
    self.lpips_loss = LPIPS(net='alex').cuda().eval()
    self.l1_loss = torch.nn.L1Loss(reduction='mean')
| 58 | |
| 59 | def loss_coarse(self, A_face, B, p1=500, p2=0.05): |
| 60 | A_face = F.interpolate(A_face, size=(64, 64), mode='area') |