| 250 | return target_height, target_width |
| 251 | |
def prepare_default_cond_input(self,
                               map_shape=[3, 12, 64, 64],
                               motion_frames=5,
                               lat_motion_frames=2,
                               enable_mano=False,
                               enable_kp=False,
                               enable_pose=False):
    """Build constant-valued condition latents for the enabled conditions.

    For every enabled flag a tensor of shape ``map_shape`` is filled with
    that condition's constant (1.0 for ``enable_mano``, -1.0 for
    ``enable_kp`` and ``enable_pose``), prefixed along dim 2 with
    ``motion_frames`` copies of its first dim-2 slice, passed through
    ``self.vae.encode``, stacked, and trimmed of its first
    ``lat_motion_frames`` entries along dim 2.

    Args:
        map_shape: Shape of the constant map handed to ``torch.ones``.
            NOTE(review): the three-index slice, the five-factor
            ``repeat`` and the ``dim=2`` concatenation below all expect a
            5-D tensor; the 4-element default looks inconsistent with
            that -- confirm the intended shape against the callers.
        motion_frames: Number of copies of the first dim-2 slice that are
            prepended before encoding.
        lat_motion_frames: Number of leading dim-2 entries dropped from
            the encoded result.
        enable_mano: Include the condition filled with 1.0.
        enable_kp: Include the first condition filled with -1.0.
        enable_pose: Include the second condition filled with -1.0.

    Returns:
        The enabled condition latents concatenated along dim 1 (in the
        order mano, kp, pose), or ``None`` when no flag is set.
    """
    # (flag, fill value) pairs, in the order they are concatenated.
    fill_by_flag = (
        (enable_mano, 1.0),
        (enable_kp, -1.0),
        (enable_pose, -1.0),
    )

    latents = []
    for enabled, fill in fill_by_flag:
        if not enabled:
            continue

        base = torch.ones(
            map_shape, dtype=self.param_dtype, device=self.device) * fill

        # Prepend `motion_frames` copies of the first dim-2 slice.
        lead = base[:, :, 0:1].repeat(1, 1, motion_frames, 1, 1)
        padded = torch.cat([lead, base], dim=2).to(self.param_dtype)

        # The encoder output is stacked into one tensor (presumably a
        # sequence of per-sample latents -- verify against the VAE).
        encoded = torch.stack(self.vae.encode(padded))

        # Drop the leading latent entries along dim 2.
        trimmed = encoded[:, :, lat_motion_frames:].to(self.param_dtype)
        latents.append(trimmed)

    if not latents:
        return None
    return torch.cat(latents, dim=1)
| 282 | |
| 283 | def encode_audio(self, audio_path, infer_frames): |
| 284 | z = self.audio_encoder.extract_audio_feat( |