MCPcopy
hub / github.com/jwohlwend/boltz / process_token_features

Function process_token_features

src/boltz/data/feature/featurizer.py:482–665  ·  view source on GitHub ↗

Get the token features.

Parameters:
- data (Tokenized): The tokenized data.
- max_tokens (int): The maximum number of tokens.

Returns:
- dict[str, Tensor]: The token features.

(
    data: Tokenized,
    max_tokens: Optional[int] = None,
    binder_pocket_conditioned_prop: Optional[float] = 0.0,
    binder_pocket_cutoff: Optional[float] = 6.0,
    binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
    only_ligand_binder_pocket: Optional[bool] = False,
    inference_binder: Optional[list[int]] = None,
    inference_pocket: Optional[list[tuple[int, int]]] = None,
)

Source from the content-addressed store, hash-verified

480
481
482def process_token_features(
483 data: Tokenized,
484 max_tokens: Optional[int] = None,
485 binder_pocket_conditioned_prop: Optional[float] = 0.0,
486 binder_pocket_cutoff: Optional[float] = 6.0,
487 binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
488 only_ligand_binder_pocket: Optional[bool] = False,
489 inference_binder: Optional[list[int]] = None,
490 inference_pocket: Optional[list[tuple[int, int]]] = None,
491) -> dict[str, Tensor]:
492 """Get the token features.
493
494 Parameters
495 ----------
496 data : Tokenized
497 The tokenized data.
498 max_tokens : int
499 The maximum number of tokens.
500
501 Returns
502 -------
503 dict[str, Tensor]
504 The token features.
505
506 """
507 # Token data
508 token_data = data.tokens
509 token_bonds = data.bonds
510
511 # Token core features
512 token_index = torch.arange(len(token_data), dtype=torch.long)
513 residue_index = from_numpy(token_data["res_idx"].copy()).long()
514 asym_id = from_numpy(token_data["asym_id"].copy()).long()
515 entity_id = from_numpy(token_data["entity_id"].copy()).long()
516 sym_id = from_numpy(token_data["sym_id"].copy()).long()
517 mol_type = from_numpy(token_data["mol_type"].copy()).long()
518 res_type = from_numpy(token_data["res_type"].copy()).long()
519 res_type = one_hot(res_type, num_classes=const.num_tokens)
520 disto_center = from_numpy(token_data["disto_coords"].copy())
521
522 # Token mask features
523 pad_mask = torch.ones(len(token_data), dtype=torch.float)
524 resolved_mask = from_numpy(token_data["resolved_mask"].copy()).float()
525 disto_mask = from_numpy(token_data["disto_mask"].copy()).float()
526 cyclic_period = from_numpy(token_data["cyclic_period"].copy())
527
528 # Token bond features
529 if max_tokens is not None:
530 pad_len = max_tokens - len(token_data)
531 num_tokens = max_tokens if pad_len > 0 else len(token_data)
532 else:
533 num_tokens = len(token_data)
534
535 tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
536 bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
537 for token_bond in token_bonds:
538 token_1 = tok_to_idx[token_bond["token_1"]]
539 token_2 = tok_to_idx[token_bond["token_2"]]

Callers 1

process · Method · 0.70

Calls 2

pad_dim · Function · 0.90
select_subset_from_mask · Function · 0.70

Tested by

no test coverage detected