MCPcopy
hub / github.com/jwohlwend/boltz / compute

Method compute

src/boltz/model/potentials/potentials.py:24–89  ·  view source on GitHub ↗
(self, coords, feats, parameters)

Source from the content-addressed store, hash-verified

22 self.parameters = parameters
23
24 def compute(self, coords, feats, parameters):
25 index, args, com_args, ref_args, operator_args = self.compute_args(
26 feats, parameters
27 )
28
29 if index.shape[1] == 0:
30 return torch.zeros(coords.shape[:-2], device=coords.device)
31
32 if com_args is not None:
33 com_index, atom_pad_mask = com_args
34 unpad_com_index = com_index[atom_pad_mask]
35 unpad_coords = coords[..., atom_pad_mask, :]
36 coords = torch.zeros(
37 (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
38 device=coords.device,
39 ).scatter_reduce(
40 -2,
41 unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
42 unpad_coords,
43 "mean",
44 )
45 else:
46 com_index, atom_pad_mask = None, None
47
48 if ref_args is not None:
49 ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args
50 coords = coords[..., ref_atom_index, :]
51 else:
52 ref_coords, ref_mask, ref_atom_index, ref_token_index = (
53 None,
54 None,
55 None,
56 None,
57 )
58
59 if operator_args is not None:
60 negation_mask, union_index = operator_args
61 else:
62 negation_mask, union_index = None, None
63
64 value = self.compute_variable(
65 coords,
66 index,
67 ref_coords=ref_coords,
68 ref_mask=ref_mask,
69 compute_gradient=False,
70 )
71 energy = self.compute_function(
72 value, *args, negation_mask=negation_mask, compute_derivative=False
73 )
74
75 if union_index is not None:
76 neg_exp_energy = torch.exp(-1 * parameters["union_lambda"] * energy)
77 Z = torch.zeros(
78 (*energy.shape[:-1], union_index.max() + 1), device=union_index.device
79 ).scatter_reduce(
80 -1,
81 union_index.expand_as(neg_exp_energy),

Callers 4

sampleMethod · 0.45
sampleMethod · 0.45
compute_parametersMethod · 0.45

Calls 3

compute_argsMethod · 0.95
compute_variableMethod · 0.95
compute_functionMethod · 0.95

Tested by

no test coverage detected