MCP · Copy · Index your code
hub / github.com/alievk/avatarify-python / PredictorLocal

Class PredictorLocal

afy/predictor_local.py:16–112  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

14
15
16class PredictorLocal:
17 def __init__(self, config_path, checkpoint_path, relative=False, adapt_movement_scale=False, device=None, enc_downscale=1):
18 self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
19 self.relative = relative
20 self.adapt_movement_scale = adapt_movement_scale
21 self.start_frame = None
22 self.start_frame_kp = None
23 self.kp_driving_initial = None
24 self.config_path = config_path
25 self.checkpoint_path = checkpoint_path
26 self.generator, self.kp_detector = self.load_checkpoints()
27 self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True, device=self.device)
28 self.source = None
29 self.kp_source = None
30 self.enc_downscale = enc_downscale
31
32 def load_checkpoints(self):
33 with open(self.config_path) as f:
34 config = yaml.load(f)
35
36 generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
37 **config['model_params']['common_params'])
38 generator.to(self.device)
39
40 kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
41 **config['model_params']['common_params'])
42 kp_detector.to(self.device)
43
44 checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
45 generator.load_state_dict(checkpoint['generator'])
46 kp_detector.load_state_dict(checkpoint['kp_detector'])
47
48 generator.eval()
49 kp_detector.eval()
50
51 return generator, kp_detector
52
53 def reset_frames(self):
54 self.kp_driving_initial = None
55
56 def set_source_image(self, source_image):
57 self.source = to_tensor(source_image).to(self.device)
58 self.kp_source = self.kp_detector(self.source)
59
60 if self.enc_downscale > 1:
61 h, w = int(self.source.shape[2] / self.enc_downscale), int(self.source.shape[3] / self.enc_downscale)
62 source_enc = torch.nn.functional.interpolate(self.source, size=(h, w), mode='bilinear')
63 else:
64 source_enc = self.source
65
66 self.generator.encode_source(source_enc)
67
68 def predict(self, driving_frame):
69 assert self.kp_source is not None, "call set_source_image()"
70
71 with torch.no_grad():
72 driving = to_tensor(driving_frame).to(self.device)
73

Callers 1

predictor_worker (method) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected