| 14 | |
| 15 | |
| 16 | class PredictorLocal: |
def __init__(self, config_path, checkpoint_path, relative=False, adapt_movement_scale=False, device=None, enc_downscale=1):
    """Load the networks from disk and initialise empty per-source state.

    Args:
        config_path: Path to the YAML model config read by ``load_checkpoints``.
        checkpoint_path: Path to the weights file read by ``load_checkpoints``.
        relative: Stored as ``self.relative``. NOTE(review): presumably selects
            relative keypoint motion inside ``predict`` -- confirm.
        adapt_movement_scale: Stored as ``self.adapt_movement_scale``.
            NOTE(review): presumably consumed by ``predict`` -- confirm.
        device: Torch device string. Falls back to ``'cuda'`` when CUDA is
            available, otherwise ``'cpu'``.
        enc_downscale: Factor by which ``set_source_image`` shrinks the source
            image before ``generator.encode_source``. Values <= 1 disable it.
    """
    self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    self.relative = relative
    self.adapt_movement_scale = adapt_movement_scale
    self.start_frame = None
    self.start_frame_kp = None
    self.kp_driving_initial = None
    self.config_path = config_path
    self.checkpoint_path = checkpoint_path
    self.generator, self.kp_detector = self.load_checkpoints()
    # face_alignment >= 1.4 renamed LandmarksType._2D to LandmarksType.TWO_D;
    # use whichever name the installed version provides.
    landmarks_type = getattr(face_alignment.LandmarksType, 'TWO_D', None)
    if landmarks_type is None:
        landmarks_type = face_alignment.LandmarksType._2D
    self.fa = face_alignment.FaceAlignment(landmarks_type, flip_input=True, device=self.device)
    # Filled in by set_source_image().
    self.source = None
    self.kp_source = None
    self.enc_downscale = enc_downscale
| 31 | |
def load_checkpoints(self):
    """Build the generator and keypoint detector and load their weights.

    Reads the YAML config at ``self.config_path`` and the checkpoint at
    ``self.checkpoint_path``. Both networks are moved to ``self.device`` and
    switched to eval mode.

    Returns:
        tuple: ``(generator, kp_detector)``.
    """
    with open(self.config_path, encoding='utf-8') as f:
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6.0 and is unsafe on older versions. The values are only
        # used as constructor kwargs, so plain YAML (safe_load) is assumed.
        config = yaml.safe_load(f)

    model_params = config['model_params']
    common_params = model_params['common_params']

    generator = OcclusionAwareGenerator(**model_params['generator_params'], **common_params)
    generator.to(self.device)

    kp_detector = KPDetector(**model_params['kp_detector_params'], **common_params)
    kp_detector.to(self.device)

    # map_location remaps stored tensors onto self.device, so a checkpoint
    # saved on GPU can be loaded on a CPU-only machine.
    checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    # Inference mode (affects layers such as dropout / batch-norm, if present).
    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
| 52 | |
def reset_frames(self):
    """Forget the stored initial driving keypoints.

    Only ``kp_driving_initial`` is cleared. The source image, its keypoints
    and ``start_frame`` / ``start_frame_kp`` are left as they are.
    """
    # NOTE(review): presumably predict() re-captures this from the next
    # driving frame -- confirm against the rest of the class.
    self.kp_driving_initial = None
| 55 | |
def set_source_image(self, source_image):
    """Set the image to animate and precompute its keypoints and encoding.

    Stores the converted tensor in ``self.source`` and its detected keypoints
    in ``self.kp_source``. It then passes the image, downscaled by
    ``self.enc_downscale`` when that is > 1, to
    ``self.generator.encode_source``.

    Args:
        source_image: Anything accepted by ``to_tensor``. NOTE(review): the
            resulting tensor is indexed as 4-D below (dims 2 and 3 treated as
            height/width) -- confirm the expected input layout.
    """
    # Inference only. Without no_grad, the detector and encoder would record an
    # autograd graph (module parameters require grad by default). That graph
    # would stay alive as long as the outputs stored on self do.
    with torch.no_grad():
        self.source = to_tensor(source_image).to(self.device)
        self.kp_source = self.kp_detector(self.source)

        if self.enc_downscale > 1:
            h = int(self.source.shape[2] / self.enc_downscale)
            w = int(self.source.shape[3] / self.enc_downscale)
            # align_corners=False is PyTorch's default for bilinear. Passing it
            # explicitly silences the warning some versions emit.
            source_enc = torch.nn.functional.interpolate(
                self.source, size=(h, w), mode='bilinear', align_corners=False)
        else:
            source_enc = self.source

        self.generator.encode_source(source_enc)
| 67 | |
| 68 | def predict(self, driving_frame): |
| 69 | assert self.kp_source is not None, "call set_source_image()" |
| 70 | |
| 71 | with torch.no_grad(): |
| 72 | driving = to_tensor(driving_frame).to(self.device) |
| 73 | |