github.com/google-deepmind/alphafold

Function atom37_to_frames

alphafold/model/all_atom_multimer.py:275–371

Computes ground-truth frames for each residue's rigid groups (up to 8 per residue) from its atom37 positions and atom mask.
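A frame here is a rotation plus a translation built from three base atoms per group. In the listing, the second base atom becomes the origin. The x-axis points from the first atom toward the origin, so the first atom lies on the negative x-axis. The third atom fixes the xy-plane, and `geometry.Rot3Array.from_two_vectors` does the orthonormalization. The NumPy sketch below reproduces that construction with Gram-Schmidt, assuming that is what `from_two_vectors` does internally. `frame_from_three_points` is a hypothetical helper, not an AlphaFold function, and the coordinates are arbitrary.

```python
import numpy as np


def frame_from_three_points(point_on_neg_x_axis, origin, point_on_xy_plane):
    """Hypothetical helper returning (rotation, translation) for one group.

    Assumed to match Rot3Array.from_two_vectors(origin - point_on_neg_x_axis,
    point_on_xy_plane - origin) followed by Rigid3Array(rotation, origin).
    The columns of the returned rotation are the local x, y and z axes.
    """
    e0 = origin - point_on_neg_x_axis
    e0 = e0 / np.linalg.norm(e0)               # local x-axis
    v1 = point_on_xy_plane - origin
    e1 = v1 - np.dot(v1, e0) * e0              # drop the component along x
    e1 = e1 / np.linalg.norm(e1)               # local y-axis
    e2 = np.cross(e0, e1)                      # local z-axis (right-handed)
    rotation = np.stack([e0, e1, e2], axis=-1)
    return rotation, origin


# Arbitrary, non-collinear toy coordinates (not real residue geometry).
a = np.array([1.2, 0.3, -0.5])
b = np.array([2.0, 1.0, 0.4])
c = np.array([2.1, 2.4, 1.0])
R, t = frame_from_three_points(a, b, c)

# Global -> local coordinates: R.T @ (p - t).
local_a = R.T @ (a - t)
local_c = R.T @ (c - t)
```

A frame built this way should pass three checks:

- `R` is orthonormal with determinant 1.
- The first atom maps to local coordinates (-d, 0, 0).
- The third atom has local z = 0 and a positive y.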

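The shape comments in the signature are `(...)` for `aatype` and `(..., 37)` for the positions and mask. They mean the inputs may carry any number of leading axes. The listing flattens those axes into a single residue axis N. Its own comment says everything is reshaped back at the end of the function, which falls outside the lines shown.

The NumPy sketch below illustrates that pattern. Several details are assumptions for illustration: the toy shapes, the dict standing in for `geometry.Vec3Array`, the `per_group` placeholder, and the `(..., 8)` output layout.

```python
import numpy as np

# Toy inputs with two leading axes, e.g. (batch=2, residues=5).
aatype = np.zeros((2, 5), dtype=np.int32)
all_atom_mask = np.ones((2, 5, 37), dtype=np.float32)
# Vec3Array keeps x, y and z as separate arrays; a dict stands in for it here.
all_atom_positions = {k: np.zeros((2, 5, 37), dtype=np.float32) for k in "xyz"}

aatype_in_shape = aatype.shape                       # (2, 5), kept for the end

# Flatten every leading axis into a single residue axis N = 10.
aatype = np.reshape(aatype, [-1])                    # (N,)
all_atom_mask = np.reshape(all_atom_mask, [-1, 37])  # (N, 37)
all_atom_positions = {k: np.reshape(v, [-1, 37])     # like jax.tree_map over x, y, z
                      for k, v in all_atom_positions.items()}

# Placeholder for a per-group result computed on the flat layout, (N, 8).
per_group = np.zeros((aatype.shape[0], 8), dtype=np.float32)

# Restore the original leading axes: (..., 8).
per_group = np.reshape(per_group, aatype_in_shape + (8,))
```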

Source:

```python
def atom37_to_frames(
    aatype: jnp.ndarray,  # (...)
    all_atom_positions: geometry.Vec3Array,  # (..., 37)
    all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[str, jnp.ndarray]:
  """Computes the frames for the up to 8 rigid groups for each residue."""
  # 0: 'backbone group',
  # 1: 'pre-omega-group', (empty)
  # 2: 'phi-group', (currently empty, because it defines only hydrogens)
  # 3: 'psi-group',
  # 4,5,6,7: 'chi1,2,3,4-group'
  aatype_in_shape = aatype.shape

  # If there is a batch axis, just flatten it away, and reshape everything
  # back at the end of the function.
  aatype = jnp.reshape(aatype, [-1])
  all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]),
                                    all_atom_positions)
  all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

  # Compute the gather indices for all residues in the chain.
  # shape (N, 8, 3)
  residx_rigidgroup_base_atom37_idx = utils.batched_gather(
      RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype)

  # Gather the base atom positions for each rigid group.
  base_atom_pos = jax.tree_map(
      lambda x: utils.batched_gather(  # pylint: disable=g-long-lambda
          x, residx_rigidgroup_base_atom37_idx, batch_dims=1),
      all_atom_positions)

  # Compute the Rigids.
  point_on_neg_x_axis = base_atom_pos[:, :, 0]
  origin = base_atom_pos[:, :, 1]
  point_on_xy_plane = base_atom_pos[:, :, 2]
  gt_rotation = geometry.Rot3Array.from_two_vectors(
      origin - point_on_neg_x_axis, point_on_xy_plane - origin)

  gt_frames = geometry.Rigid3Array(gt_rotation, origin)

  # Compute a mask whether the group exists.
  # (N, 8)
  group_exists = utils.batched_gather(RESTYPE_RIGIDGROUP_MASK, aatype)

  # Compute a mask whether ground truth exists for the group
  gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
      all_atom_mask.astype(jnp.float32),
      residx_rigidgroup_base_atom37_idx,
      batch_dims=1)
  gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

  # Adapt backbone frame to old convention (mirror x-axis and z-axis).
  rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
  rots[0, 0, 0] = -1
  rots[0, 2, 2] = -1
  gt_frames = gt_frames.compose_rotation(
      geometry.Rot3Array.from_array(rots))
  # ... (listing ends at source line 332; the function continues to line 371)
```
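The two `utils.batched_gather(..., batch_dims=1)` calls in the listing do the same kind of lookup. For every residue, they pick the three base atoms of each of the 8 groups out of that residue's 37 atom slots. This sketch does not reproduce `utils.batched_gather` itself. Assuming those semantics, the gather and the `gt_exists` mask can be written with `np.take_along_axis`. The small index and mask tables below are illustrative stand-ins, not the real `RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX` or `RESTYPE_RIGIDGROUP_MASK` values.

```python
import numpy as np

N, NUM_GROUPS = 2, 8

# Illustrative stand-ins for the tables after the aatype gather (not the real
# RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX / RESTYPE_RIGIDGROUP_MASK values).
# base_idx[n, g] holds the atom37 indices of group g's three base atoms.
base_idx = np.zeros((N, NUM_GROUPS, 3), dtype=np.int64)
base_idx[:, 0] = [2, 1, 0]    # a group-0 (backbone) triple
base_idx[:, 4] = [0, 1, 3]    # a group-4 (chi1) triple
group_exists = np.zeros((N, NUM_GROUPS), dtype=np.float32)
group_exists[:, [0, 4]] = 1.0

# Atom mask (N, 37): both residues have atoms 0..3, but residue 1 lacks atom 3.
all_atom_mask = np.zeros((N, 37), dtype=np.float32)
all_atom_mask[:, :4] = 1.0
all_atom_mask[1, 3] = 0.0

# Per-residue gather along the atom axis (what batch_dims=1 expresses).
# Positions would be gathered the same way, once per x, y, z array.
flat_idx = base_idx.reshape(N, -1)                                   # (N, 24)
gt_atoms_exist = np.take_along_axis(all_atom_mask, flat_idx, axis=1)
gt_atoms_exist = gt_atoms_exist.reshape(N, NUM_GROUPS, 3)            # (N, 8, 3)

# Ground truth exists only if the group exists and all three base atoms do.
gt_exists = np.min(gt_atoms_exist, axis=-1) * group_exists           # (N, 8)
```

Residue 1 loses its group-4 frame because one of that group's base atoms is missing. Groups that do not exist for a residue type are zeroed by `group_exists` regardless of the atom mask.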

Callers

nothing calls this directly

Calls (4)

compose_rotation · method · 0.95
from_two_vectors · method · 0.80
zeros · method · 0.80
from_array · method · 0.45
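`from_array` and `compose_rotation` are used in the last step shown in the listing. An (8, 3, 3) stack of identity matrices is built, and its group-0 entry is set to diag(-1, 1, -1). That stack is composed onto every frame, which flips the backbone frame's x- and z-axes.

The NumPy sketch below shows the effect on a single rotation matrix. It assumes `compose_rotation` right-multiplies the frame's rotation and leaves its translation unchanged. Under that assumption, columns 0 and 2 change sign and the result is still a proper rotation, because diag(-1, 1, -1) is a 180-degree turn about the y-axis rather than a reflection.

```python
import numpy as np

# The per-group correction built in the listing: identity for every group,
# except group 0 (backbone), where x and z are negated: diag(-1, 1, -1).
rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
rots[0, 0, 0] = -1
rots[0, 2, 2] = -1

# Arbitrary proper rotation standing in for a frame's rotation (90 deg about z).
R = np.array([[0.0, -1.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0]], dtype=np.float32)

# Assumed compose_rotation semantics: new rotation = R @ rots[g];
# the translation is left unchanged, so it is omitted here.
R_backbone = R @ rots[0]
R_chi1 = R @ rots[4]    # identity correction: unchanged
```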

Tested by

no test coverage detected