Compute the RMSD of the aligned atom coordinates. Parameters: pred_atom_coords (torch.Tensor) — predicted atom coordinates; atom_coords (torch.Tensor) — ground-truth atom coordinates; atom_mask (torch.Tensor) — resolved-atom mask; atom_to_token (torch.Tensor) — atom-to-token mapping; …
(
pred_atom_coords,
atom_coords,
atom_mask,
atom_to_token,
mol_type,
nucleotide_weight=5.0,
ligand_weight=10.0,
)
| 960 | |
| 961 | |
| 962 | def weighted_minimum_rmsd_single( |
| 963 | pred_atom_coords, |
| 964 | atom_coords, |
| 965 | atom_mask, |
| 966 | atom_to_token, |
| 967 | mol_type, |
| 968 | nucleotide_weight=5.0, |
| 969 | ligand_weight=10.0, |
| 970 | ): |
| 971 | """Compute rmsd of the aligned atom coordinates. |
| 972 | |
| 973 | Parameters |
| 974 | ---------- |
| 975 | pred_atom_coords : torch.Tensor |
| 976 | Predicted atom coordinates |
| 977 | atom_coords: torch.Tensor |
| 978 | Ground truth atom coordinates |
| 979 | atom_mask : torch.Tensor |
| 980 | Resolved atom mask |
| 981 | atom_to_token : torch.Tensor |
| 982 | Atom to token mapping |
| 983 | mol_type : torch.Tensor |
| 984 | Atom type |
| 985 | |
| 986 | Returns |
| 987 | ------- |
| 988 | Tensor |
| 989 | The rmsd |
| 990 | Tensor |
| 991 | The aligned coordinates |
| 992 | Tensor |
| 993 | The aligned weights |
| 994 | |
| 995 | """ |
| 996 | align_weights = atom_coords.new_ones(atom_coords.shape[:2]) |
| 997 | atom_type = ( |
| 998 | torch.bmm(atom_to_token.float(), mol_type.unsqueeze(-1).float()) |
| 999 | .squeeze(-1) |
| 1000 | .long() |
| 1001 | ) |
| 1002 | |
| 1003 | align_weights = align_weights * ( |
| 1004 | 1 |
| 1005 | + nucleotide_weight |
| 1006 | * ( |
| 1007 | torch.eq(atom_type, const.chain_type_ids["DNA"]).float() |
| 1008 | + torch.eq(atom_type, const.chain_type_ids["RNA"]).float() |
| 1009 | ) |
| 1010 | + ligand_weight |
| 1011 | * torch.eq(atom_type, const.chain_type_ids["NONPOLYMER"]).float() |
| 1012 | ) |
| 1013 | |
| 1014 | with torch.no_grad(): |
| 1015 | atom_coords_aligned_ground_truth = weighted_rigid_align( |
| 1016 | atom_coords, pred_atom_coords, align_weights, mask=atom_mask |
| 1017 | ) |
| 1018 | |
| 1019 | # weighted MSE loss of denoised atom positions |
no test coverage detected