Repository: github.com/jwohlwend/boltz

Function weighted_minimum_rmsd_single

src/boltz/model/loss/validation.py, lines 962–1025

Compute the RMSD of the aligned atom coordinates.

(
    pred_atom_coords,
    atom_coords,
    atom_mask,
    atom_to_token,
    mol_type,
    nucleotide_weight=5.0,
    ligand_weight=10.0,
)


def weighted_minimum_rmsd_single(
    pred_atom_coords,
    atom_coords,
    atom_mask,
    atom_to_token,
    mol_type,
    nucleotide_weight=5.0,
    ligand_weight=10.0,
):
    """Compute the RMSD of the aligned atom coordinates.

    Parameters
    ----------
    pred_atom_coords : torch.Tensor
        Predicted atom coordinates
    atom_coords : torch.Tensor
        Ground truth atom coordinates
    atom_mask : torch.Tensor
        Resolved atom mask
    atom_to_token : torch.Tensor
        Atom-to-token mapping, shape (B, N_atoms, N_tokens)
    mol_type : torch.Tensor
        Molecule (chain) type of each token, shape (B, N_tokens)
    nucleotide_weight : float, optional
        Extra alignment weight added for DNA and RNA atoms, by default 5.0
    ligand_weight : float, optional
        Extra alignment weight added for ligand (non-polymer) atoms, by default 10.0

    Returns
    -------
    Tensor
        The RMSD
    Tensor
        The aligned coordinates
    Tensor
        The alignment weights

    """
    # Every atom starts with weight 1.
    align_weights = atom_coords.new_ones(atom_coords.shape[:2])
    # Broadcast each token's molecule type onto its atoms: (B, N_atoms).
    atom_type = (
        torch.bmm(atom_to_token.float(), mol_type.unsqueeze(-1).float())
        .squeeze(-1)
        .long()
    )

    # Up-weight nucleic-acid and ligand atoms in the superposition.
    align_weights = align_weights * (
        1
        + nucleotide_weight
        * (
            torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
            + torch.eq(atom_type, const.chain_type_ids["RNA"]).float()
        )
        + ligand_weight
        * torch.eq(atom_type, const.chain_type_ids["NONPOLYMER"]).float()
    )

    # Align the ground truth onto the prediction without tracking gradients.
    with torch.no_grad():
        atom_coords_aligned_ground_truth = weighted_rigid_align(
            atom_coords, pred_atom_coords, align_weights, mask=atom_mask
        )

    # weighted MSE loss of denoised atom positions
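The weighting step in the listing can be reproduced in isolation to see which atoms end up with which weight. This is a minimal sketch, not Boltz code: the integer chain-type IDs in `CHAIN_TYPE_IDS` are hypothetical stand-ins for `const.chain_type_ids` (whose real values are not shown on this page), the helper name `per_atom_alignment_weights` is made up for illustration, and `atom_to_token` is assumed to be one-hot with shape (B, N_atoms, N_tokens), which is the layout the `torch.bmm` call implies.

```python
import torch

# HYPOTHETICAL chain-type IDs for illustration only; Boltz reads the real
# values from const.chain_type_ids, which this page does not show.
CHAIN_TYPE_IDS = {"PROTEIN": 0, "DNA": 1, "RNA": 2, "NONPOLYMER": 3}


def per_atom_alignment_weights(atom_to_token, mol_type,
                               nucleotide_weight=5.0, ligand_weight=10.0):
    """Hypothetical helper mirroring the weighting step of the listing.

    atom_to_token: (B, N_atoms, N_tokens) one-hot atom-to-token map
    mol_type:      (B, N_tokens) integer molecule type of each token
    returns:       (B, N_atoms) alignment weight of each atom
    """
    # Batched matmul copies each token's type onto the atoms it owns.
    atom_type = (
        torch.bmm(atom_to_token.float(), mol_type.unsqueeze(-1).float())
        .squeeze(-1)
        .long()
    )
    is_nucleotide = (atom_type == CHAIN_TYPE_IDS["DNA"]) | (
        atom_type == CHAIN_TYPE_IDS["RNA"]
    )
    is_ligand = atom_type == CHAIN_TYPE_IDS["NONPOLYMER"]
    # Base weight 1, plus an extra term for nucleic-acid and ligand atoms.
    return (
        1.0
        + nucleotide_weight * is_nucleotide.float()
        + ligand_weight * is_ligand.float()
    )


# Toy input: 3 tokens (protein, RNA, ligand) owning 5 atoms in total.
mol_type = torch.tensor([[0, 2, 3]])
token_of_atom = torch.tensor([[0, 0, 1, 2, 2]])
atom_to_token = torch.nn.functional.one_hot(token_of_atom, num_classes=3)
print(per_atom_alignment_weights(atom_to_token, mol_type).tolist())
# [[1.0, 1.0, 6.0, 11.0, 11.0]]
```

With the default arguments a ligand atom counts 11 times as much as a protein atom in the fit, and a DNA or RNA atom 6 times as much, so small molecules and nucleic acids are not drowned out by a large protein when the structures are superimposed.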

Callers (2)

minimum_symmetry_coords (function)

Calls (1)

weighted_rigid_align (function)
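The source of `weighted_rigid_align` is not shown here. Its name and call site suggest a weighted rigid-body superposition of the ground truth onto the prediction, so the sketch below shows the textbook technique for that (a weighted Kabsch fit) followed by a weighted RMSD over resolved atoms. It is an illustration under that assumption, not Boltz's implementation; `weighted_kabsch_align` and `weighted_rmsd` are hypothetical names.

```python
import torch


def weighted_kabsch_align(true_coords, pred_coords, weights, mask):
    """Hypothetical: superimpose true_coords onto pred_coords (weighted Kabsch).

    true_coords, pred_coords: (B, N, 3); weights, mask: (B, N).
    Atoms with mask == 0 do not influence the fit but are still transformed.
    """
    w = (weights * mask).unsqueeze(-1)                        # (B, N, 1)
    w_sum = w.sum(dim=1, keepdim=True).clamp_min(1e-8)       # (B, 1, 1)
    # Weighted centroids of both structures.
    true_c = (w * true_coords).sum(dim=1, keepdim=True) / w_sum
    pred_c = (w * pred_coords).sum(dim=1, keepdim=True) / w_sum
    p = true_coords - true_c
    q = pred_coords - pred_c
    # Weighted 3x3 cross-covariance H = sum_i w_i p_i q_i^T.
    h = torch.einsum("bni,bnj->bij", w * p, q)
    u, _, vh = torch.linalg.svd(h)
    v, ut = vh.transpose(-1, -2), u.transpose(-1, -2)
    # Flip the last axis if needed so that R is a rotation, not a reflection.
    diag = torch.ones(h.shape[0], 3, dtype=h.dtype, device=h.device)
    diag[:, -1] = torch.sign(torch.linalg.det(v @ ut))
    rot = v @ torch.diag_embed(diag) @ ut                    # (B, 3, 3)
    # Rotate the centered ground truth, then move it onto the prediction's centroid.
    return p @ rot.transpose(-1, -2) + pred_c


def weighted_rmsd(pred_coords, aligned_true, weights, mask):
    """Hypothetical: weighted RMSD over atoms with mask == 1; shape (B,)."""
    sq_dist = ((pred_coords - aligned_true) ** 2).sum(dim=-1)    # (B, N)
    wm = weights * mask
    return torch.sqrt((sq_dist * wm).sum(dim=-1) / wm.sum(dim=-1).clamp_min(1e-8))


# Toy check: the ground truth is a rotated and shifted copy of the prediction.
torch.manual_seed(0)
pred = torch.randn(1, 20, 3, dtype=torch.float64)
rot_true, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
if torch.linalg.det(rot_true) < 0:
    rot_true[:, 0] = -rot_true[:, 0]  # make it a proper rotation
true = pred @ rot_true.T + torch.tensor([1.0, -2.0, 3.0], dtype=torch.float64)
weights = 1.0 + torch.rand(1, 20, dtype=torch.float64)
mask = torch.ones(1, 20, dtype=torch.float64)
aligned = weighted_kabsch_align(true, pred, weights, mask)
print(weighted_rmsd(pred, aligned, weights, mask))  # ~0 up to floating-point error
```

Because the same per-atom weights drive both the fit and the error, heavily weighted atoms dominate the superposition. In `weighted_minimum_rmsd_single` the alignment runs under `torch.no_grad()`, so no gradient flows through the superposition itself.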

Tested by

No test coverage detected.