Computes the frames for the up to 8 rigid groups for each residue.
(
aatype: jnp.ndarray, # (...)
all_atom_positions: geometry.Vec3Array, # (..., 37)
all_atom_mask: jnp.ndarray, # (..., 37)
)
| 273 | |
| 274 | |
| 275 | def atom37_to_frames( |
| 276 | aatype: jnp.ndarray, # (...) |
| 277 | all_atom_positions: geometry.Vec3Array, # (..., 37) |
| 278 | all_atom_mask: jnp.ndarray, # (..., 37) |
| 279 | ) -> Dict[str, jnp.ndarray]: |
| 280 | """Computes the frames for the up to 8 rigid groups for each residue.""" |
| 281 | # 0: 'backbone group', |
| 282 | # 1: 'pre-omega-group', (empty) |
| 283 | # 2: 'phi-group', (currently empty, because it defines only hydrogens) |
| 284 | # 3: 'psi-group', |
| 285 | # 4,5,6,7: 'chi1,2,3,4-group' |
| 286 | aatype_in_shape = aatype.shape |
| 287 | |
| 288 | # If there is a batch axis, just flatten it away, and reshape everything |
| 289 | # back at the end of the function. |
| 290 | aatype = jnp.reshape(aatype, [-1]) |
| 291 | all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]), |
| 292 | all_atom_positions) |
| 293 | all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37]) |
| 294 | |
| 295 | # Compute the gather indices for all residues in the chain. |
| 296 | # shape (N, 8, 3) |
| 297 | residx_rigidgroup_base_atom37_idx = utils.batched_gather( |
| 298 | RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype) |
| 299 | |
| 300 | # Gather the base atom positions for each rigid group. |
| 301 | base_atom_pos = jax.tree_map( |
| 302 | lambda x: utils.batched_gather( # pylint: disable=g-long-lambda |
| 303 | x, residx_rigidgroup_base_atom37_idx, batch_dims=1), |
| 304 | all_atom_positions) |
| 305 | |
| 306 | # Compute the Rigids. |
| 307 | point_on_neg_x_axis = base_atom_pos[:, :, 0] |
| 308 | origin = base_atom_pos[:, :, 1] |
| 309 | point_on_xy_plane = base_atom_pos[:, :, 2] |
| 310 | gt_rotation = geometry.Rot3Array.from_two_vectors( |
| 311 | origin - point_on_neg_x_axis, point_on_xy_plane - origin) |
| 312 | |
| 313 | gt_frames = geometry.Rigid3Array(gt_rotation, origin) |
| 314 | |
| 315 | # Compute a mask whether the group exists. |
| 316 | # (N, 8) |
| 317 | group_exists = utils.batched_gather(RESTYPE_RIGIDGROUP_MASK, aatype) |
| 318 | |
| 319 | # Compute a mask whether ground truth exists for the group |
| 320 | gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3) |
| 321 | all_atom_mask.astype(jnp.float32), |
| 322 | residx_rigidgroup_base_atom37_idx, |
| 323 | batch_dims=1) |
| 324 | gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8) |
| 325 | |
| 326 | # Adapt backbone frame to old convention (mirror x-axis and z-axis). |
| 327 | rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) |
| 328 | rots[0, 0, 0] = -1 |
| 329 | rots[0, 2, 2] = -1 |
| 330 | gt_frames = gt_frames.compose_rotation( |
| 331 | geometry.Rot3Array.from_array(rots)) |
| 332 |
Nothing calls this function directly.
No test coverage was detected for it.