(
coords: torch.Tensor,
feats: dict,
index_batch: int,
**args_rmsd,
)
| 237 | |
| 238 | |
| 239 | def minimum_lddt_symmetry_coords( |
| 240 | coords: torch.Tensor, |
| 241 | feats: dict, |
| 242 | index_batch: int, |
| 243 | **args_rmsd, |
| 244 | ): |
| 245 | all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords) |
| 246 | all_resolved_mask = ( |
| 247 | feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool) |
| 248 | ) |
| 249 | crop_to_all_atom_map = ( |
| 250 | feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long) |
| 251 | ) |
| 252 | chain_symmetries = feats["chain_symmetries"][index_batch] |
| 253 | amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch] |
| 254 | ligand_symmetries = feats["ligand_symmetries"][index_batch] |
| 255 | |
| 256 | dmat_predicted = torch.cdist( |
| 257 | coords[:, : len(crop_to_all_atom_map)], coords[:, : len(crop_to_all_atom_map)] |
| 258 | ) |
| 259 | |
| 260 | # Check best symmetry on chain swap |
| 261 | best_true_coords = None |
| 262 | best_lddt = 0 |
| 263 | for c in chain_symmetries: |
| 264 | true_all_coords = all_coords.clone() |
| 265 | true_all_resolved_mask = all_resolved_mask.clone() |
| 266 | for start1, end1, start2, end2, chainidx1, chainidx2 in c: |
| 267 | true_all_coords[:, start1:end1] = all_coords[:, start2:end2] |
| 268 | true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2] |
| 269 | true_coords = true_all_coords[:, crop_to_all_atom_map] |
| 270 | true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map] |
| 271 | dmat_true = torch.cdist(true_coords, true_coords) |
| 272 | pair_mask = ( |
| 273 | true_resolved_mask[:, None] |
| 274 | * true_resolved_mask[None, :] |
| 275 | * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask) |
| 276 | ) |
| 277 | |
| 278 | lddt = lddt_dist( |
| 279 | dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False |
| 280 | )[0] |
| 281 | lddt = lddt.item() |
| 282 | |
| 283 | if lddt > best_lddt: |
| 284 | best_lddt = lddt |
| 285 | best_true_coords = true_coords |
| 286 | best_true_resolved_mask = true_resolved_mask |
| 287 | |
| 288 | # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment |
| 289 | true_coords = best_true_coords.clone() |
| 290 | true_resolved_mask = best_true_resolved_mask.clone() |
| 291 | for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries: |
| 292 | for c in symmetric_amino_or_lig: |
| 293 | # starting from greedy best, try to swap the atoms |
| 294 | new_true_coords = true_coords.clone() |
| 295 | new_true_resolved_mask = true_resolved_mask.clone() |
| 296 | indices = [] |
nothing calls this directly
no test coverage detected