MCPcopy
hub / github.com/jwohlwend/boltz / compute_frames_nonpolymer

Function compute_frames_nonpolymer

src/boltz/data/feature/featurizer.py:34–113  ·  view source on GitHub ↗

Get the frames for non-polymer tokens. Parameters: data (Tokenized) — the tokenized data; frame_data (list) — the frame data; resolved_frame_data (list) — the resolved frame data. Returns: tuple[list, list] — the frame data and resolved frame data.

(
    data: Tokenized,
    coords,
    resolved_mask,
    atom_to_token,
    frame_data: list,
    resolved_frame_data: list,
)

Source from the content-addressed store, hash-verified

32
33
34def compute_frames_nonpolymer(
35 data: Tokenized,
36 coords,
37 resolved_mask,
38 atom_to_token,
39 frame_data: list,
40 resolved_frame_data: list,
41) -> tuple[list, list]:
42 """Get the frames for non-polymer tokens.
43
44 Parameters
45 ----------
46 data : Tokenized
47 The tokenized data.
48 frame_data : list
49 The frame data.
50 resolved_frame_data : list
51 The resolved frame data.
52
53 Returns
54 -------
55 tuple[list, list]
56 The frame data and resolved frame data.
57
58 """
59 frame_data = np.array(frame_data)
60 resolved_frame_data = np.array(resolved_frame_data)
61 asym_id_token = data.tokens["asym_id"]
62 asym_id_atom = data.tokens["asym_id"][atom_to_token]
63 token_idx = 0
64 atom_idx = 0
65 for id in np.unique(data.tokens["asym_id"]):
66 mask_chain_token = asym_id_token == id
67 mask_chain_atom = asym_id_atom == id
68 num_tokens = mask_chain_token.sum()
69 num_atoms = mask_chain_atom.sum()
70 if (
71 data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
72 or num_atoms < 3
73 ):
74 token_idx += num_tokens
75 atom_idx += num_atoms
76 continue
77 dist_mat = (
78 (
79 coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
80 - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
81 )
82 ** 2
83 ).sum(-1) ** 0.5
84 resolved_pair = 1 - (
85 resolved_mask[mask_chain_atom][None, :]
86 * resolved_mask[mask_chain_atom][:, None]
87 ).astype(np.float32)
88 resolved_pair[resolved_pair == 1] = math.inf
89 indices = np.argsort(dist_mat + resolved_pair, axis=1)
90 frames = (
91 np.concatenate(

Callers 1

process_atom_features Function · 0.70

Calls 1

compute_collinear_mask Function · 0.70

Tested by

no test coverage detected