Generate random input tensors and move them to GPU
(batch_size=1, device_id=0)
| 18 | |
| 19 | |
def initialize_inputs(batch_size=1, device_id=0, dtype=torch.float16):
    """
    Generate random input tensors directly on the target device.

    Tensors are sampled from a standard normal distribution on ``device_id``
    in ``dtype``. This avoids building fp32 tensors on the host, copying them
    over and then casting them.

    Args:
        batch_size: Size of the leading dimension of every generated tensor.
        device_id: Target device, passed as ``device=`` to ``torch.randn``.
            An int ordinal such as ``0`` selects that accelerator, the same
            as the original ``.to(device_id)``.
        dtype: Dtype of every returned tensor. The default ``torch.float16``
            reproduces the previously hard-coded ``.half()`` behavior.

    Returns:
        dict mapping input names to tensors:
            'feature_3d':      (batch_size, 32, 16, 64, 64)
            'kp_source':       (batch_size, 21, 3)
            'kp_driving':      (batch_size, 21, 3)
            'source_image':    (batch_size, 3, 256, 256)
            'generator_input': (batch_size, 256, 64, 64)
            'feat_stitching', 'feat_eye', 'feat_lip': outputs of
                ``concat_feat`` (defined elsewhere; shapes depend on it).
    """
    def _randn(*shape):
        # Sample on the device in the final dtype: no host allocation,
        # no host-to-device copy, no intermediate fp32 device tensor.
        return torch.randn(batch_size, *shape, device=device_id, dtype=dtype)

    feature_3d = _randn(32, 16, 64, 64)
    # presumably 21 keypoints x 3 coordinates -- TODO confirm with the model
    kp_source = _randn(21, 3)
    kp_driving = _randn(21, 3)
    source_image = _randn(3, 256, 256)
    generator_input = _randn(256, 64, 64)
    # These two are only used to build feat_eye / feat_lip; not returned.
    eye_close_ratio = _randn(3)
    lip_close_ratio = _randn(2)

    # concat_feat's output dtype is not visible here, so cast explicitly.
    feat_stitching = concat_feat(kp_source, kp_driving).to(dtype)
    feat_eye = concat_feat(kp_source, eye_close_ratio).to(dtype)
    feat_lip = concat_feat(kp_source, lip_close_ratio).to(dtype)

    inputs = {
        'feature_3d': feature_3d,
        'kp_source': kp_source,
        'kp_driving': kp_driving,
        'source_image': source_image,
        'generator_input': generator_input,
        'feat_stitching': feat_stitching,
        'feat_eye': feat_eye,
        'feat_lip': feat_lip,
    }

    return inputs
| 47 | |
| 48 | |
| 49 | def load_and_compile_models(cfg, model_config): |
no test coverage detected