Get the token features.

Parameters
----------
data : Tokenized
    The tokenized data.
max_tokens : int, optional
    The maximum number of tokens. Defaults to None.

Returns
-------
dict[str, Tensor]
    The token features.
(
data: Tokenized,
max_tokens: Optional[int] = None,
binder_pocket_conditioned_prop: Optional[float] = 0.0,
binder_pocket_cutoff: Optional[float] = 6.0,
binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
only_ligand_binder_pocket: Optional[bool] = False,
inference_binder: Optional[list[int]] = None,
inference_pocket: Optional[list[tuple[int, int]]] = None,
)
| 480 | |
| 481 | |
| 482 | def process_token_features( |
| 483 | data: Tokenized, |
| 484 | max_tokens: Optional[int] = None, |
| 485 | binder_pocket_conditioned_prop: Optional[float] = 0.0, |
| 486 | binder_pocket_cutoff: Optional[float] = 6.0, |
| 487 | binder_pocket_sampling_geometric_p: Optional[float] = 0.0, |
| 488 | only_ligand_binder_pocket: Optional[bool] = False, |
| 489 | inference_binder: Optional[list[int]] = None, |
| 490 | inference_pocket: Optional[list[tuple[int, int]]] = None, |
| 491 | ) -> dict[str, Tensor]: |
| 492 | """Get the token features. |
| 493 | |
| 494 | Parameters |
| 495 | ---------- |
| 496 | data : Tokenized |
| 497 | The tokenized data. |
| 498 | max_tokens : int |
| 499 | The maximum number of tokens. |
| 500 | |
| 501 | Returns |
| 502 | ------- |
| 503 | dict[str, Tensor] |
| 504 | The token features. |
| 505 | |
| 506 | """ |
| 507 | # Token data |
| 508 | token_data = data.tokens |
| 509 | token_bonds = data.bonds |
| 510 | |
| 511 | # Token core features |
| 512 | token_index = torch.arange(len(token_data), dtype=torch.long) |
| 513 | residue_index = from_numpy(token_data["res_idx"].copy()).long() |
| 514 | asym_id = from_numpy(token_data["asym_id"].copy()).long() |
| 515 | entity_id = from_numpy(token_data["entity_id"].copy()).long() |
| 516 | sym_id = from_numpy(token_data["sym_id"].copy()).long() |
| 517 | mol_type = from_numpy(token_data["mol_type"].copy()).long() |
| 518 | res_type = from_numpy(token_data["res_type"].copy()).long() |
| 519 | res_type = one_hot(res_type, num_classes=const.num_tokens) |
| 520 | disto_center = from_numpy(token_data["disto_coords"].copy()) |
| 521 | |
| 522 | # Token mask features |
| 523 | pad_mask = torch.ones(len(token_data), dtype=torch.float) |
| 524 | resolved_mask = from_numpy(token_data["resolved_mask"].copy()).float() |
| 525 | disto_mask = from_numpy(token_data["disto_mask"].copy()).float() |
| 526 | cyclic_period = from_numpy(token_data["cyclic_period"].copy()) |
| 527 | |
| 528 | # Token bond features |
| 529 | if max_tokens is not None: |
| 530 | pad_len = max_tokens - len(token_data) |
| 531 | num_tokens = max_tokens if pad_len > 0 else len(token_data) |
| 532 | else: |
| 533 | num_tokens = len(token_data) |
| 534 | |
| 535 | tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)} |
| 536 | bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float) |
| 537 | for token_bond in token_bonds: |
| 538 | token_1 = tok_to_idx[token_bond["token_1"]] |
| 539 | token_2 = tok_to_idx[token_bond["token_2"]] |
no test coverage detected