Repository: github.com/ml-jku/hopfield-layers

Method get_association_matrix

Defined in hflayers/__init__.py, lines 795–809.

Fetches the Hopfield association matrix used for lookup, obtained by passing the specified data through the Hopfield association. The input is first prepared by _prepare_input and then handed, together with both optional masks, to the inner Hopfield module (self.hopfield). The call runs under torch.no_grad(), so the returned matrix carries no gradient history.

(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
 association_mask: Optional[Tensor] = None) -> Tensor
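As a rough illustration of what the returned matrix contains, here is a minimal NumPy sketch. It assumes, without confirmation from this page, that the association matrix is the row-wise softmax of β-scaled dot products between state (query) patterns and stored (key) patterns, as in the modern Hopfield update rule, and that both masks use True to mean "exclude", as PyTorch's attention masks do. The learned projections, multiple heads, and dropout of the real Hopfield core are omitted; the function name association_matrix and the parameter beta are hypothetical.

```python
import numpy as np


def association_matrix(state, stored, beta=1.0,
                       stored_pattern_padding_mask=None,
                       association_mask=None):
    """Hypothetical sketch: row-wise softmax(beta * state @ stored.T) with masking.

    state:  (num_state, d) state (query) patterns
    stored: (num_stored, d) stored (key) patterns
    stored_pattern_padding_mask: (num_stored,) bools, True = ignore that stored pattern
    association_mask: (num_state, num_stored) bools, True = block that pair
    Assumes every row keeps at least one unmasked stored pattern.
    """
    state = np.asarray(state, dtype=float)
    stored = np.asarray(stored, dtype=float)
    scores = beta * state @ stored.T              # (num_state, num_stored)

    # Collect every (state, stored) pair that must receive zero weight.
    blocked = np.zeros(scores.shape, dtype=bool)
    if stored_pattern_padding_mask is not None:
        blocked |= np.asarray(stored_pattern_padding_mask, dtype=bool)[None, :]
    if association_mask is not None:
        blocked |= np.asarray(association_mask, dtype=bool)
    scores = np.where(blocked, -np.inf, scores)

    # Numerically stable softmax over the stored-pattern axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


state = [[1.0, 0.0]]
stored = [[1.0, 0.0], [0.0, 1.0], [5.0, 0.0]]
A = association_matrix(state, stored, stored_pattern_padding_mask=[False, False, True])
print(A.round(3))
```

For the query [1, 0] against stored patterns [1, 0], [0, 1] and [5, 0] with the third pattern padded out, the weights are e/(e+1) ≈ 0.731, 1/(e+1) ≈ 0.269 and 0: the masked pattern gets no weight even though it has the largest dot product.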

Source (method excerpt; the enclosing class is not shown):

def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                           association_mask: Optional[Tensor] = None) -> Tensor:
    """
    Fetch Hopfield association matrix used for lookup gathered by passing through the specified data.

    :param input: data to be passed through the Hopfield association
    :param stored_pattern_padding_mask: mask to be applied on stored patterns
    :param association_mask: mask to be applied on inner association matrix
    :return: association matrix as computed by the Hopfield core module
    """
    with torch.no_grad():
        return self.hopfield.get_association_matrix(
            input=self._prepare_input(input=input),
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)
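The method itself does only two things: it converts the raw input into the patterns the inner module expects via _prepare_input, and it forwards the result, together with both masks, to self.hopfield.get_association_matrix without tracking gradients. The NumPy sketch below isolates that prepare-then-delegate structure using hypothetical classes (RecordingCore, PoolingLikeWrapper). The assumption that _prepare_input returns a (stored, state) pair is illustrative only; the real helper may return a different tuple.

```python
import numpy as np


class RecordingCore:
    """Hypothetical stand-in for the inner Hopfield module; records each call."""

    def __init__(self):
        self.calls = []

    def get_association_matrix(self, input, stored_pattern_padding_mask=None,
                               association_mask=None):
        self.calls.append((input, stored_pattern_padding_mask, association_mask))
        stored, state = input
        # Placeholder result: raw dot products, not a real association matrix.
        return state @ stored.T


class PoolingLikeWrapper:
    """Hypothetical wrapper mirroring the prepare-then-delegate structure."""

    def __init__(self, core, state_patterns):
        self.hopfield = core
        self.state_patterns = np.asarray(state_patterns, dtype=float)

    def _prepare_input(self, input):
        # Assumption: the raw data become the stored patterns and a fixed set of
        # state patterns (e.g. learned pooling weights) queries them.
        return np.asarray(input, dtype=float), self.state_patterns

    def get_association_matrix(self, input, stored_pattern_padding_mask=None,
                               association_mask=None):
        # The original wraps this call in torch.no_grad(); NumPy has no autograd,
        # so only the delegation itself is reproduced here.
        return self.hopfield.get_association_matrix(
            input=self._prepare_input(input=input),
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)


core = RecordingCore()
wrapper = PoolingLikeWrapper(core, state_patterns=[[1.0, 0.0]])
out = wrapper.get_association_matrix([[1.0, 2.0], [3.0, 4.0]],
                                     stored_pattern_padding_mask=[False, True])
print(out)               # raw scores: one state pattern vs. two stored patterns
print(core.calls[0][1])  # the padding mask reached the core unchanged
```

Injecting a recording stand-in for the core makes it easy to verify that the masks are forwarded untouched, which, together with input preparation and disabling gradients, is all the wrapper contributes.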

Callers

Nothing calls this method directly.

Calls (2)

_prepare_input (method) · 0.95
self.hopfield.get_association_matrix (method of the inner Hopfield module)

Tested by

No test coverage detected.
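Since no tests cover this method, a hedged starting point is a helper that checks properties one would expect of the returned matrix if it is a row-wise softmax over stored patterns: non-negative weights, rows summing to 1, and zero weight on padded stored patterns. Whether these hold for this library is an assumption. Dropout, for example, would break the row sums, so the module should be in evaluation mode (eval()) when such a check is applied. The helper name is_valid_association_matrix is hypothetical. If the matrix has leading batch or head dimensions, the padding mask must be reshaped to broadcast, e.g. from (batch, num_stored) to (batch, 1, 1, num_stored).

```python
import numpy as np


def is_valid_association_matrix(A, stored_pattern_padding_mask=None, atol=1e-6):
    """Hypothetical test helper: check softmax-style invariants of an association matrix.

    Assumes the last axis indexes stored patterns, each row is a softmax computed
    with dropout disabled, and the padding mask (True = padded) broadcasts against A.
    """
    A = np.asarray(A, dtype=float)
    if np.any(A < 0.0):
        return False                                   # weights must be non-negative
    if not np.allclose(A.sum(axis=-1), 1.0, atol=atol):
        return False                                   # each row must sum to 1
    if stored_pattern_padding_mask is not None:
        mask = np.broadcast_to(np.asarray(stored_pattern_padding_mask, dtype=bool), A.shape)
        if np.any(A[mask] != 0.0):
            return False                               # padded patterns get no weight
    return True


good = [[0.25, 0.75, 0.0], [0.5, 0.5, 0.0]]
print(is_valid_association_matrix(good, stored_pattern_padding_mask=[False, False, True]))
```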