Fetch Hopfield association matrix gathered by passing through the specified data. :param input: data to be passed through the Hopfield association :param stored_pattern_padding_mask: mask to be applied on stored patterns :param association_mask: mask to be applied on inner association matrix :return: association matrix as computed by the Hopfield core module
(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
stored_pattern_padding_mask: Optional[Tensor] = None,
association_mask: Optional[Tensor] = None)
| 238 | return association_output |
| 239 | |
| 240 | def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], |
| 241 | stored_pattern_padding_mask: Optional[Tensor] = None, |
| 242 | association_mask: Optional[Tensor] = None) -> Tensor: |
| 243 | """ |
| 244 | Fetch Hopfield association matrix gathered by passing through the specified data. |
| 245 | |
| 246 | :param input: data to be passed through the Hopfield association |
| 247 | :param stored_pattern_padding_mask: mask to be applied on stored patterns |
| 248 | :param association_mask: mask to be applied on inner association matrix |
| 249 | :return: association matrix as computed by the Hopfield core module |
| 250 | """ |
| 251 | with torch.no_grad(): |
| 252 | return self._associate( |
| 253 | data=input, return_raw_associations=True, |
| 254 | stored_pattern_padding_mask=stored_pattern_padding_mask, |
| 255 | association_mask=association_mask)[2] |
| 256 | |
| 257 | def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], |
| 258 | stored_pattern_padding_mask: Optional[Tensor] = None, |
no test coverage detected