Get the token features.

Parameters
----------
data : Tokenized
    The input data to the model.
max_tokens : int
    The maximum number of tokens.

Returns
-------
dict[str, Tensor]
    The token features.
( # noqa: C901, PLR0915, PLR0912
data: Tokenized,
random: np.random.Generator,
max_tokens: Optional[int] = None,
binder_pocket_conditioned_prop: Optional[float] = 0.0,
contact_conditioned_prop: Optional[float] = 0.0,
binder_pocket_cutoff_min: Optional[float] = 4.0,
binder_pocket_cutoff_max: Optional[float] = 20.0,
binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
only_ligand_binder_pocket: Optional[bool] = False,
only_pp_contact: Optional[bool] = False,
inference_pocket_constraints: Optional[
list[tuple[int, list[tuple[int, int]], float]]
] = False,
inference_contact_constraints: Optional[
list[tuple[tuple[int, int], tuple[int, int], float]]
] = False,
override_method: Optional[str] = None,
)
| 606 | |
| 607 | |
| 608 | def process_token_features( # noqa: C901, PLR0915, PLR0912 |
| 609 | data: Tokenized, |
| 610 | random: np.random.Generator, |
| 611 | max_tokens: Optional[int] = None, |
| 612 | binder_pocket_conditioned_prop: Optional[float] = 0.0, |
| 613 | contact_conditioned_prop: Optional[float] = 0.0, |
| 614 | binder_pocket_cutoff_min: Optional[float] = 4.0, |
| 615 | binder_pocket_cutoff_max: Optional[float] = 20.0, |
| 616 | binder_pocket_sampling_geometric_p: Optional[float] = 0.0, |
| 617 | only_ligand_binder_pocket: Optional[bool] = False, |
| 618 | only_pp_contact: Optional[bool] = False, |
| 619 | inference_pocket_constraints: Optional[ |
| 620 | list[tuple[int, list[tuple[int, int]], float]] |
| 621 | ] = False, |
| 622 | inference_contact_constraints: Optional[ |
| 623 | list[tuple[tuple[int, int], tuple[int, int], float]] |
| 624 | ] = False, |
| 625 | override_method: Optional[str] = None, |
| 626 | ) -> dict[str, Tensor]: |
| 627 | """Get the token features. |
| 628 | |
| 629 | Parameters |
| 630 | ---------- |
| 631 | data : Tokenized |
| 632 | The input data to the model. |
| 633 | max_tokens : int |
| 634 | The maximum number of tokens. |
| 635 | |
| 636 | Returns |
| 637 | ------- |
| 638 | dict[str, Tensor] |
| 639 | The token features. |
| 640 | |
| 641 | """ |
| 642 | # Token data |
| 643 | token_data = data.tokens |
| 644 | token_bonds = data.bonds |
| 645 | |
| 646 | # Token core features |
| 647 | token_index = torch.arange(len(token_data), dtype=torch.long) |
| 648 | residue_index = from_numpy(token_data["res_idx"]).long() |
| 649 | asym_id = from_numpy(token_data["asym_id"]).long() |
| 650 | entity_id = from_numpy(token_data["entity_id"]).long() |
| 651 | sym_id = from_numpy(token_data["sym_id"]).long() |
| 652 | mol_type = from_numpy(token_data["mol_type"]).long() |
| 653 | res_type = from_numpy(token_data["res_type"]).long() |
| 654 | res_type = one_hot(res_type, num_classes=const.num_tokens) |
| 655 | disto_center = from_numpy(token_data["disto_coords"]) |
| 656 | modified = from_numpy(token_data["modified"]).long() # float() |
| 657 | cyclic_period = from_numpy(token_data["cyclic_period"].copy()) |
| 658 | affinity_mask = from_numpy(token_data["affinity_mask"]).float() |
| 659 | |
| 660 | ## Conditioning features ## |
| 661 | method = ( |
| 662 | np.zeros(len(token_data)) |
| 663 | + const.method_types_ids[ |
| 664 | ( |
| 665 | "x-ray diffraction" |
no test coverage detected