(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False)
| 422 | |
| 423 | |
def lddt_dist(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False):
    """Score predicted pairwise distances against reference distances (lDDT).

    A pair (i, j) is considered when its reference distance is strictly below
    ``cutoff``; it is additionally weighted by ``mask``. Each considered pair
    earns 0.25 for every tolerance in (0.5, 1.0, 2.0, 4.0) that its absolute
    distance error is strictly below, so per-pair scores lie in
    {0, 0.25, 0.5, 0.75, 1}.

    Args:
        dmat_predicted: Predicted distance matrix (torch tensor). Distances are
            in the same units as ``cutoff`` and the tolerances (presumably
            Angstroms -- TODO confirm).
        dmat_true: Reference distance matrix, shape-compatible with
            ``dmat_predicted``.
        mask: Pairwise mask, multiplied in as per-pair weights. The diagonal
            (self-pairs) must already be masked out by the caller.
        cutoff: Pairs whose reference distance is >= ``cutoff`` are ignored.
        per_atom: If True, reduce over the last axis only (one score per row);
            otherwise reduce over the last two axes.

    Returns:
        per_atom=True:  ``(score, has_pairs)``. ``has_pairs`` is a float
            tensor: 1.0 where the row's total pair weight is non-zero, else 0.0.
        per_atom=False: ``(score, total)``. ``total`` is the summed pair weight
            over the last two axes.
        When the total pair weight is zero, ``score`` evaluates to roughly 1.0.
        This is because the same 1e-10 epsilon is added to the numerator and
        the denominator. The second return value lets callers detect that case.
    """
    # Per-pair weight: inside the inclusion radius AND allowed by the mask.
    pair_weight = (dmat_true < cutoff).float() * mask
    abs_err = torch.abs(dmat_true - dmat_predicted)

    # Count how many of the four tolerance thresholds each pair satisfies,
    # then scale so that satisfying all four gives 1.0.
    hits = sum((abs_err < tol).float() for tol in (0.5, 1.0, 2.0, 4.0))
    per_pair = 0.25 * hits

    reduce_dims = -1 if per_atom else (-2, -1)
    weight_total = torch.sum(pair_weight, dim=reduce_dims)
    weighted_hits = torch.sum(pair_weight * per_pair, dim=reduce_dims)

    # Epsilon on both sides avoids 0/0 when no pair is selected.
    inv_total = 1.0 / (1e-10 + weight_total)
    lddt = inv_total * (1e-10 + weighted_hits)

    if not per_atom:
        return lddt, weight_total
    return lddt, (weight_total != 0).float()
| 448 | |
| 449 | |
| 450 | def express_coordinate_in_frame(atom_coords, frame_atom_a, frame_atom_b, frame_atom_c): |
no outgoing calls
no test coverage detected