github.com/google-deepmind/alphafold

Method loss

alphafold/model/folding_multimer.py:627–749
(self, value: Mapping[str, Any], batch: Mapping[str, Any]) -> Dict[str, Any]
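
In the portion of the method listed here (source lines 627–684), `value` supplies model predictions and `batch` supplies input features and ground truth. As a quick illustration, the hypothetical helper below (`missing_loss_inputs` is not part of AlphaFold) reports which of the keys read by that portion are absent from a given pair of mappings. It is a sketch only: the unlisted remainder of the method may read further keys.

```python
from typing import Any, Dict, List, Mapping

# Keys read by the listed portion of `loss` (source lines 627-684 only).
# The rest of the method is not shown and may read additional keys.
BATCH_KEYS = ('aatype', 'all_atom_positions', 'all_atom_mask',
              'seq_mask', 'residue_index', 'asym_id')
VALUE_KEYS = ('final_atom14_positions', 'sidechains', 'traj')


def missing_loss_inputs(value: Mapping[str, Any],
                        batch: Mapping[str, Any]) -> Dict[str, List[str]]:
  """Hypothetical helper: lists expected keys absent from `value` and `batch`."""
  return {
      'value': [k for k in VALUE_KEYS if k not in value],
      'batch': [k for k in BATCH_KEYS if k not in batch],
  }


# Usage with placeholder contents (real inputs would be arrays).
report = missing_loss_inputs({'traj': None}, {'aatype': None, 'seq_mask': None})
print(report)
```

Checking only key presence keeps the helper independent of array shapes, which the listing does not spell out.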


def loss(self,
         value: Mapping[str, Any],
         batch: Mapping[str, Any]
         ) -> Dict[str, Any]:

  raise NotImplementedError(
      'This function should be called on a batch with reordered chains (see '
      'Evans et al (2021) Section 7.3. Multi-Chain Permutation Alignment.')

  # As released, the raise above always fires, so nothing below executes.
  ret = {'loss': 0.}

  ret['metrics'] = {}

  aatype = batch['aatype']
  all_atom_positions = batch['all_atom_positions']
  all_atom_positions = geometry.Vec3Array.from_array(all_atom_positions)
  all_atom_mask = batch['all_atom_mask']
  seq_mask = batch['seq_mask']
  residue_index = batch['residue_index']

  # Ground-truth backbone frames and chi angles.
  gt_rigid, gt_affine_mask = make_backbone_affine(all_atom_positions,
                                                  all_atom_mask,
                                                  aatype)

  chi_angles, chi_mask = all_atom_multimer.compute_chi_angles(
      all_atom_positions, all_atom_mask, aatype)

  # Atom14 mask and coordinates of the predicted structure.
  pred_mask = all_atom_multimer.get_atom14_mask(aatype)
  pred_mask *= seq_mask[:, None]
  pred_positions = value['final_atom14_positions']
  pred_positions = geometry.Vec3Array.from_array(pred_positions)

  # Ground truth in atom14 layout; alt_naming_is_better flags residues where
  # swapping symmetric side-chain atoms better matches the prediction.
  gt_positions, gt_mask, alt_naming_is_better = compute_atom14_gt(
      aatype, all_atom_positions, all_atom_mask, pred_positions)

  # Structural violations (e.g. bond geometry, clashes) in the prediction.
  violations = find_structural_violations(
      aatype=aatype,
      residue_index=residue_index,
      mask=pred_mask,
      pred_positions=pred_positions,
      config=self.config,
      asym_id=batch['asym_id'])

  sidechains = value['sidechains']

  # Shift ground-truth chi angles to agree with the chosen atom naming.
  gt_chi_angles = get_renamed_chi_angles(aatype, chi_angles,
                                         alt_naming_is_better)

  # Several violation metrics:
  violation_metrics = compute_violation_metrics(
      residue_index=residue_index,
      mask=pred_mask,
      seq_mask=seq_mask,
      pred_positions=pred_positions,
      violations=violations)
  ret['metrics'].update(violation_metrics)

  target_rigid = geometry.Rigid3Array.from_array(value['traj'])
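
As released, the method raises immediately; its message says it should only be called on a batch whose chains have already been reordered to match the prediction (Evans et al. 2021, Section 7.3, Multi-Chain Permutation Alignment). The sketch below illustrates that idea in a deliberately simplified form and is not the paper's procedure: it assumes every chain is interchangeable (a homomer), summarizes each chain by its centroid, and brute-forces all orderings of the ground-truth chains. `best_chain_permutation` and `centroid` are hypothetical names.

```python
import itertools
import math
from typing import List, Sequence, Tuple

Vec = Tuple[float, float, float]


def centroid(points: Sequence[Vec]) -> Vec:
  """Mean position of one chain's atoms."""
  n = len(points)
  return (sum(p[0] for p in points) / n,
          sum(p[1] for p in points) / n,
          sum(p[2] for p in points) / n)


def best_chain_permutation(pred_chains: Sequence[Sequence[Vec]],
                           gt_chains: Sequence[Sequence[Vec]]) -> List[int]:
  """Simplified stand-in for multi-chain permutation alignment.

  Assumes all chains are interchangeable and tries every ordering of the
  ground-truth chains, keeping the one whose centroids have the smallest
  summed squared distance to the predicted centroids. Returns `perm` such
  that gt_chains[perm[i]] is matched to pred_chains[i].
  """
  pred_c = [centroid(c) for c in pred_chains]
  gt_c = [centroid(c) for c in gt_chains]
  best, best_cost = None, math.inf
  for perm in itertools.permutations(range(len(gt_c))):
    cost = sum(math.dist(pred_c[i], gt_c[j]) ** 2 for i, j in enumerate(perm))
    if cost < best_cost:
      best, best_cost = list(perm), cost
  return best


# Two copies of a chain; the ground truth lists them in the opposite order.
pred = [[(0., 0., 0.), (1., 0., 0.)], [(10., 0., 0.), (11., 0., 0.)]]
gt = [[(10.2, 0., 0.), (11.1, 0., 0.)], [(0.1, 0., 0.), (0.9, 0., 0.)]]
perm = best_chain_permutation(pred, gt)
reordered_gt = [gt[j] for j in perm]  # now aligned chain-by-chain with pred
```

Exhaustive search grows as n! in the number of chains, so it only suits a handful of copies; a practical implementation would need a cheaper assignment strategy and would restrict swaps to chains with identical sequences.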

Callers

nothing calls this directly

Calls 11

make_backbone_affine · Function · 0.85
compute_atom14_gt · Function · 0.85
get_renamed_chi_angles · Function · 0.85
compute_frames · Function · 0.85
backbone_loss · Function · 0.70
sidechain_loss · Function · 0.70
supervised_chi_loss · Function · 0.70
from_array · Method · 0.45
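
`get_renamed_chi_angles(aatype, chi_angles, alt_naming_is_better)` adjusts the ground-truth chi angles that `supervised_chi_loss` later consumes. The sketch below shows the idea for a single residue in plain Python, under stated assumptions: swapping two symmetry-equivalent side-chain atoms rotates a pi-periodic dihedral by pi, so when the alternative naming is chosen those chi angles are shifted by pi and wrapped back into [-pi, pi]. The periodicity flags are passed in directly here rather than looked up from `aatype`, and `renamed_chi_angles` / `wrap_angle` are hypothetical names, not AlphaFold functions.

```python
import math
from typing import List, Sequence


def wrap_angle(a: float) -> float:
  """Bring an angle in [-pi, 2*pi] back into [-pi, pi]."""
  return a - 2 * math.pi if a > math.pi else a


def renamed_chi_angles(chi: Sequence[float],
                       pi_periodic: Sequence[bool],
                       alt_naming_is_better: bool) -> List[float]:
  """Single-residue sketch of chi-angle renaming.

  If the alternative atom naming was chosen, each pi-periodic chi angle is
  shifted by pi; all other chi angles are returned unchanged.
  """
  if not alt_naming_is_better:
    return list(chi)
  return [wrap_angle(c + math.pi) if p else c
          for c, p in zip(chi, pi_periodic)]


# Second chi is pi-periodic; the alternative naming was preferred.
print(renamed_chi_angles([0.0, 2.0], [False, True], True))
```

The AlphaFold call applies the same kind of rule to whole arrays at once, with the per-chi periodicity determined by each residue's `aatype`.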

Tested by

no test coverage detected