(self,
value: Mapping[str, Any],
batch: Mapping[str, Any]
)
| 625 | return no_loss_ret |
| 626 | |
| 627 | def loss(self, |
| 628 | value: Mapping[str, Any], |
| 629 | batch: Mapping[str, Any] |
| 630 | ) -> Dict[str, Any]: |
| 631 | |
| 632 | raise NotImplementedError( |
| 633 | 'This function should be called on a batch with reordered chains (see ' |
| 634 | 'Evans et al (2021) Section 7.3. Multi-Chain Permutation Alignment.') |
| 635 | |
| 636 | ret = {'loss': 0.} |
| 637 | |
| 638 | ret['metrics'] = {} |
| 639 | |
| 640 | aatype = batch['aatype'] |
| 641 | all_atom_positions = batch['all_atom_positions'] |
| 642 | all_atom_positions = geometry.Vec3Array.from_array(all_atom_positions) |
| 643 | all_atom_mask = batch['all_atom_mask'] |
| 644 | seq_mask = batch['seq_mask'] |
| 645 | residue_index = batch['residue_index'] |
| 646 | |
| 647 | gt_rigid, gt_affine_mask = make_backbone_affine(all_atom_positions, |
| 648 | all_atom_mask, |
| 649 | aatype) |
| 650 | |
| 651 | chi_angles, chi_mask = all_atom_multimer.compute_chi_angles( |
| 652 | all_atom_positions, all_atom_mask, aatype) |
| 653 | |
| 654 | pred_mask = all_atom_multimer.get_atom14_mask(aatype) |
| 655 | pred_mask *= seq_mask[:, None] |
| 656 | pred_positions = value['final_atom14_positions'] |
| 657 | pred_positions = geometry.Vec3Array.from_array(pred_positions) |
| 658 | |
| 659 | gt_positions, gt_mask, alt_naming_is_better = compute_atom14_gt( |
| 660 | aatype, all_atom_positions, all_atom_mask, pred_positions) |
| 661 | |
| 662 | violations = find_structural_violations( |
| 663 | aatype=aatype, |
| 664 | residue_index=residue_index, |
| 665 | mask=pred_mask, |
| 666 | pred_positions=pred_positions, |
| 667 | config=self.config, |
| 668 | asym_id=batch['asym_id']) |
| 669 | |
| 670 | sidechains = value['sidechains'] |
| 671 | |
| 672 | gt_chi_angles = get_renamed_chi_angles(aatype, chi_angles, |
| 673 | alt_naming_is_better) |
| 674 | |
| 675 | # Several violation metrics: |
| 676 | violation_metrics = compute_violation_metrics( |
| 677 | residue_index=residue_index, |
| 678 | mask=pred_mask, |
| 679 | seq_mask=seq_mask, |
| 680 | pred_positions=pred_positions, |
| 681 | violations=violations) |
| 682 | ret['metrics'].update(violation_metrics) |
| 683 | |
| 684 | target_rigid = geometry.Rigid3Array.from_array(value['traj']) |
nothing calls this directly
no test coverage detected