Fetch the Hopfield association matrix used for pooling, computed by passing the specified data through the module. :param input: data to be passed through the Hopfield association :param stored_pattern_padding_mask: mask to be applied on stored patterns :param association_mask: mask to be applied on inner association matrix :return: association matrix as computed by the Hopfield core module
(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
stored_pattern_padding_mask: Optional[Tensor] = None,
association_mask: Optional[Tensor] = None)
| 510 | association_mask=association_mask).flatten(start_dim=1) |
| 511 | |
| 512 | def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]], |
| 513 | stored_pattern_padding_mask: Optional[Tensor] = None, |
| 514 | association_mask: Optional[Tensor] = None) -> Tensor: |
| 515 | """ |
| 516 | Fetch the Hopfield association matrix used for pooling, computed by passing the specified data through the module. |
| 517 | |
| 518 | :param input: data to be passed through the Hopfield association (a tensor or a tuple of two tensors); it is first preprocessed by ``_prepare_input`` |
| 519 | :param stored_pattern_padding_mask: optional mask to be applied on stored patterns (forwarded unchanged to the Hopfield core module) |
| 520 | :param association_mask: optional mask to be applied on the inner association matrix (forwarded unchanged to the Hopfield core module) |
| 521 | :return: association matrix as computed by the Hopfield core module, evaluated without gradient tracking |
| 522 | """ |
| 523 | with torch.no_grad():  # inspection-only call: no autograd graph is recorded for the returned matrix |
| 524 | return self.hopfield.get_association_matrix( |
| 525 | input=self._prepare_input(input=input), |
| 526 | stored_pattern_padding_mask=stored_pattern_padding_mask, |
| 527 | association_mask=association_mask) |
| 528 | |
| 529 | def get_projected_pattern_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]], |
| 530 | stored_pattern_padding_mask: Optional[Tensor] = None, |
nothing calls this directly
no test coverage detected