Get the atom features.

Parameters
----------
data : Tokenized
    The tokenized data.
max_atoms : int, optional
    The maximum number of atoms.

Returns
-------
dict[str, Tensor]
    The atom features.
(
data: Tokenized,
atoms_per_window_queries: int = 32,
min_dist: float = 2.0,
max_dist: float = 22.0,
num_bins: int = 64,
max_atoms: Optional[int] = None,
max_tokens: Optional[int] = None,
)
| 666 | |
| 667 | |
| 668 | def process_atom_features( |
| 669 | data: Tokenized, |
| 670 | atoms_per_window_queries: int = 32, |
| 671 | min_dist: float = 2.0, |
| 672 | max_dist: float = 22.0, |
| 673 | num_bins: int = 64, |
| 674 | max_atoms: Optional[int] = None, |
| 675 | max_tokens: Optional[int] = None, |
| 676 | ) -> dict[str, Tensor]: |
| 677 | """Get the atom features. |
| 678 | |
| 679 | Parameters |
| 680 | ---------- |
| 681 | data : Tokenized |
| 682 | The tokenized data. |
| 683 | max_atoms : int, optional |
| 684 | The maximum number of atoms. |
| 685 | |
| 686 | Returns |
| 687 | ------- |
| 688 | dict[str, Tensor] |
| 689 | The atom features. |
| 690 | |
| 691 | """ |
| 692 | # Filter to tokens' atoms |
| 693 | atom_data = [] |
| 694 | ref_space_uid = [] |
| 695 | coord_data = [] |
| 696 | frame_data = [] |
| 697 | resolved_frame_data = [] |
| 698 | atom_to_token = [] |
| 699 | token_to_rep_atom = [] # index on cropped atom table |
| 700 | r_set_to_rep_atom = [] |
| 701 | disto_coords = [] |
| 702 | atom_idx = 0 |
| 703 | |
| 704 | chain_res_ids = {} |
| 705 | for token_id, token in enumerate(data.tokens): |
| 706 | # Get the chain residue ids |
| 707 | chain_idx, res_id = token["asym_id"], token["res_idx"] |
| 708 | chain = data.structure.chains[chain_idx] |
| 709 | |
| 710 | if (chain_idx, res_id) not in chain_res_ids: |
| 711 | new_idx = len(chain_res_ids) |
| 712 | chain_res_ids[(chain_idx, res_id)] = new_idx |
| 713 | else: |
| 714 | new_idx = chain_res_ids[(chain_idx, res_id)] |
| 715 | |
| 716 | # Map atoms to token indices |
| 717 | ref_space_uid.extend([new_idx] * token["atom_num"]) |
| 718 | atom_to_token.extend([token_id] * token["atom_num"]) |
| 719 | |
| 720 | # Add atom data |
| 721 | start = token["atom_idx"] |
| 722 | end = token["atom_idx"] + token["atom_num"] |
| 723 | token_atoms = data.structure.atoms[start:end] |
| 724 | |
| 725 | # Map token to representative atom |
no test coverage detected