Repository: github.com/jwohlwend/boltz

Function lddt_dist

src/boltz/model/loss/confidence.py:424–447
(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False)
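Read directly off the function body, the score can be written as follows. The symbols are chosen here for exposition and do not appear in the Boltz source: $d$ = `dmat_true`, $\hat d$ = `dmat_predicted`, $M$ = `mask`, $c$ = `cutoff`, and $\epsilon = 10^{-10}$. The index $j$ runs over the last axis of the distance matrices.

$$
w_{ij} = \mathbb{1}\left[d_{ij} < c\right] M_{ij}, \qquad
s_{ij} = \frac{1}{4} \sum_{t \in \{0.5,\,1,\,2,\,4\}} \mathbb{1}\left[\,\lvert d_{ij} - \hat d_{ij} \rvert < t\,\right]
$$

$$
\texttt{per\_atom=False:}\quad \mathrm{lDDT} = \frac{\epsilon + \sum_{i,j} w_{ij}\, s_{ij}}{\epsilon + \sum_{i,j} w_{ij}},
\qquad
\texttt{per\_atom=True:}\quad \mathrm{lDDT}_i = \frac{\epsilon + \sum_{j} w_{ij}\, s_{ij}}{\epsilon + \sum_{j} w_{ij}}
$$

The second return value is $\sum_{i,j} w_{ij}$ (the number of scored pairs) in the global mode and $\mathbb{1}\left[\sum_j w_{ij} \neq 0\right]$ (whether atom $i$ has any scored pair) in the per-atom mode.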


```python
import torch


def lddt_dist(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False):
    # NOTE: the mask is a pairwise mask which should have the identity elements already masked out
    # Compute mask over distances: only pairs whose true distance is below the cutoff are scored.
    dists_to_score = (dmat_true < cutoff).float() * mask
    dist_l1 = torch.abs(dmat_true - dmat_predicted)

    # Fraction of the four lDDT tolerance thresholds (0.5, 1, 2, 4) that each pair satisfies.
    score = 0.25 * (
        (dist_l1 < 0.5).float()
        + (dist_l1 < 1.0).float()
        + (dist_l1 < 2.0).float()
        + (dist_l1 < 4.0).float()
    )

    # Normalize over the appropriate axes.
    # The 1e-10 terms avoid division by zero; with no scored pairs the result is ~1.0.
    if per_atom:
        # Average over partners j (last axis) for each atom i.
        # Despite its name, mask_no_match is True where atom i HAS at least one scored pair.
        mask_no_match = torch.sum(dists_to_score, dim=-1) != 0
        norm = 1.0 / (1e-10 + torch.sum(dists_to_score, dim=-1))
        score = norm * (1e-10 + torch.sum(dists_to_score * score, dim=-1))
        return score, mask_no_match.float()
    else:
        # Average over all scored pairs and also return how many pairs were scored.
        norm = 1.0 / (1e-10 + torch.sum(dists_to_score, dim=(-2, -1)))
        score = norm * (1e-10 + torch.sum(dists_to_score * score, dim=(-2, -1)))
        total = torch.sum(dists_to_score, dim=(-1, -2))
        return score, total
```
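A minimal usage sketch, assuming PyTorch and unbatched `[N, 3]` coordinates in Å. `pairwise_inputs` is a hypothetical helper (not part of Boltz) that builds the two distance matrices with `torch.cdist` and a pairwise mask whose diagonal is zeroed, as the NOTE in the function requires. The function body is repeated verbatim so the snippet runs on its own.

```python
import torch


# Verbatim copy of lddt_dist, repeated so this snippet is self-contained.
def lddt_dist(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False):
    dists_to_score = (dmat_true < cutoff).float() * mask
    dist_l1 = torch.abs(dmat_true - dmat_predicted)
    score = 0.25 * (
        (dist_l1 < 0.5).float()
        + (dist_l1 < 1.0).float()
        + (dist_l1 < 2.0).float()
        + (dist_l1 < 4.0).float()
    )
    if per_atom:
        mask_no_match = torch.sum(dists_to_score, dim=-1) != 0
        norm = 1.0 / (1e-10 + torch.sum(dists_to_score, dim=-1))
        score = norm * (1e-10 + torch.sum(dists_to_score * score, dim=-1))
        return score, mask_no_match.float()
    else:
        norm = 1.0 / (1e-10 + torch.sum(dists_to_score, dim=(-2, -1)))
        score = norm * (1e-10 + torch.sum(dists_to_score * score, dim=(-2, -1)))
        total = torch.sum(dists_to_score, dim=(-1, -2))
        return score, total


def pairwise_inputs(true_coords, pred_coords):
    """Hypothetical helper: [N, 3] coordinates -> (dmat_pred, dmat_true, mask)."""
    dmat_true = torch.cdist(true_coords, true_coords)
    dmat_pred = torch.cdist(pred_coords, pred_coords)
    n = true_coords.shape[0]
    # Self-pairs (i, i) must already be masked out, per the NOTE in lddt_dist.
    mask = 1.0 - torch.eye(n)
    return dmat_pred, dmat_true, mask


# Three collinear atoms; the prediction shifts atom 2 by 0.7 Å along x.
true_coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
pred_coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.7, 0.0, 0.0]])
dmat_pred, dmat_true, mask = pairwise_inputs(true_coords, pred_coords)

# Pairs touching atom 2 have |error| = 0.7: they pass the 1, 2 and 4 thresholds -> 0.75.
# The (0, 1) pairs have zero error -> 1.0.
global_score, n_pairs = lddt_dist(dmat_pred, dmat_true, mask)
per_atom_score, has_pairs = lddt_dist(dmat_pred, dmat_true, mask, per_atom=True)
print(global_score.item(), n_pairs.item())  # ~0.8333 (= 5 / 6), 6.0
print(per_atom_score.tolist())  # [0.875, 0.875, 0.75]

# A cutoff below every true distance leaves nothing to score: each per-atom value
# degenerates to ~1.0 through the 1e-10 terms, and the second output flags it as invalid.
empty_score, empty_valid = lddt_dist(dmat_pred, dmat_true, mask, cutoff=0.5, per_atom=True)
print(empty_valid.tolist())  # [0.0, 0.0, 0.0]
```

The last case shows why the second return value matters. In per-atom mode, an atom with no partner inside the cutoff still receives a score of about 1.0, so its validity flag is what marks it as unscored. In global mode, the pair count could serve as a weight when averaging over examples, though how the callers listed below use it is not shown on this page.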

Callers 6

factored_lddt_loss (function) · 0.90
compute_plddt_mae (function) · 0.90
plddt_loss (function) · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected