(self, coords, feats, parameters)
| 22 | self.parameters = parameters |
| 23 | |
| 24 | def compute(self, coords, feats, parameters): |
| 25 | index, args, com_args, ref_args, operator_args = self.compute_args( |
| 26 | feats, parameters |
| 27 | ) |
| 28 | |
| 29 | if index.shape[1] == 0: |
| 30 | return torch.zeros(coords.shape[:-2], device=coords.device) |
| 31 | |
| 32 | if com_args is not None: |
| 33 | com_index, atom_pad_mask = com_args |
| 34 | unpad_com_index = com_index[atom_pad_mask] |
| 35 | unpad_coords = coords[..., atom_pad_mask, :] |
| 36 | coords = torch.zeros( |
| 37 | (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3), |
| 38 | device=coords.device, |
| 39 | ).scatter_reduce( |
| 40 | -2, |
| 41 | unpad_com_index.unsqueeze(-1).expand_as(unpad_coords), |
| 42 | unpad_coords, |
| 43 | "mean", |
| 44 | ) |
| 45 | else: |
| 46 | com_index, atom_pad_mask = None, None |
| 47 | |
| 48 | if ref_args is not None: |
| 49 | ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args |
| 50 | coords = coords[..., ref_atom_index, :] |
| 51 | else: |
| 52 | ref_coords, ref_mask, ref_atom_index, ref_token_index = ( |
| 53 | None, |
| 54 | None, |
| 55 | None, |
| 56 | None, |
| 57 | ) |
| 58 | |
| 59 | if operator_args is not None: |
| 60 | negation_mask, union_index = operator_args |
| 61 | else: |
| 62 | negation_mask, union_index = None, None |
| 63 | |
| 64 | value = self.compute_variable( |
| 65 | coords, |
| 66 | index, |
| 67 | ref_coords=ref_coords, |
| 68 | ref_mask=ref_mask, |
| 69 | compute_gradient=False, |
| 70 | ) |
| 71 | energy = self.compute_function( |
| 72 | value, *args, negation_mask=negation_mask, compute_derivative=False |
| 73 | ) |
| 74 | |
| 75 | if union_index is not None: |
| 76 | neg_exp_energy = torch.exp(-1 * parameters["union_lambda"] * energy) |
| 77 | Z = torch.zeros( |
| 78 | (*energy.shape[:-1], union_index.max() + 1), device=union_index.device |
| 79 | ).scatter_reduce( |
| 80 | -1, |
| 81 | union_index.expand_as(neg_exp_energy), |
no test coverage detected