Fetch the Hopfield association matrix used for lookup, gathered by passing the specified data through the module. :param input: data to be passed through the Hopfield association :param stored_pattern_padding_mask: mask to be applied on stored patterns :param association_mask: mask to be applied on inner association matrix :return: association matrix as computed by the Hopfield core module
(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
association_mask: Optional[Tensor] = None)
| 793 | association_mask=association_mask) |
| 794 | |
| 795 | def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None, |
| 796 | association_mask: Optional[Tensor] = None) -> Tensor: |
| 797 | """ |
| 798 | Fetch Hopfield association matrix used for lookup gathered by passing through the specified data. |
| 799 | |
| 800 | :param input: data to be passed through the Hopfield association |
| 801 | :param stored_pattern_padding_mask: mask to be applied on stored patterns |
| 802 | :param association_mask: mask to be applied on inner association matrix |
| 803 | :return: association matrix as computed by the Hopfield core module |
| 804 | """ |
| 805 | with torch.no_grad(): |
| 806 | return self.hopfield.get_association_matrix( |
| 807 | input=self._prepare_input(input=input), |
| 808 | stored_pattern_padding_mask=stored_pattern_padding_mask, |
| 809 | association_mask=association_mask) |
| 810 | |
| 811 | def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]], |
| 812 | stored_pattern_padding_mask: Optional[Tensor] = None, |
nothing calls this directly
no test coverage detected