Implementation of SMPLify using 3D joints.
| 45 | |
| 46 | # SMPLIfy 3D |
| 47 | class SMPLify3D(): |
| 48 | """Implementation of SMPLify, use 3D joints.""" |
| 49 | |
def __init__(self,
             smplxmodel,
             step_size=1e-2,
             batch_size=1,
             num_iters=100,
             use_collision=False,
             use_lbfgs=True,
             joints_category="orig",
             device=torch.device('cuda:0'),
             ):
    """Record the fitting options and prepare the prior and joint indices.

    Args:
        smplxmodel: body model object; it is kept as ``self.smpl`` and its
            ``faces_tensor`` attribute is flattened into ``self.model_faces``.
        step_size: stored as ``self.step_size`` (presumably the optimizer
            step size -- confirm against the fitting code).
        batch_size: stored as ``self.batch_size``.
        num_iters: stored as ``self.num_iters``.
        use_collision: when truthy, ``self.part_segm_fn`` is additionally
            set from ``config.Part_Seg_DIR``.
        use_lbfgs: stored as ``self.use_lbfgs``; flags the optimizer choice.
        joints_category: ``"orig"`` or ``"AMASS"``; selects which index
            lists from ``config`` become ``self.smpl_index`` and
            ``self.corr_index``. Any other value leaves both as ``None``
            and prints a warning instead of raising.
        device: device the GMM pose prior is moved to; also stored as
            ``self.device``.
    """
    # Plain option bookkeeping.
    self.step_size = step_size
    self.num_iters = num_iters
    self.batch_size = batch_size
    self.device = device
    self.use_lbfgs = use_lbfgs  # optimizer selection flag

    # Build the 8-component GMM pose prior, then move it to the target device.
    gmm_prior = MaxMixturePrior(
        prior_folder=config.GMM_MODEL_DIR,
        num_gaussians=8,
        dtype=torch.float32,
    )
    self.pose_prior = gmm_prior.to(device)

    # The part-segmentation path is only recorded when collisions are enabled.
    self.use_collision = use_collision
    if self.use_collision:
        self.part_segm_fn = config.Part_Seg_DIR

    # Keep the body model and a flattened view of its face indices.
    self.smpl = smplxmodel
    self.model_faces = smplxmodel.faces_tensor.view(-1)

    # Pick the joint index lists that match the requested category.
    # (An "MMM" category is not enabled.)
    self.joints_category = joints_category
    if joints_category == "orig":
        self.smpl_index, self.corr_index = (
            config.full_smpl_idx,
            config.full_smpl_idx,
        )
    elif joints_category == "AMASS":
        self.smpl_index, self.corr_index = (
            config.amass_smpl_idx,
            config.amass_idx,
        )
    else:
        # Unknown category: leave the indices unset and warn on stdout.
        self.smpl_index, self.corr_index = None, None
        print("NO SUCH JOINTS CATEGORY!")
| 99 | |
| 100 | # ---- get the main function here ------ |
| 101 | def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0): |
| 102 | """Perform body fitting. |
| 103 | Input: |
| 104 | init_pose: SMPL pose estimate |