Repository: github.com/ml-jku/hopfield-layers

Method get_association_matrix

hflayers/__init__.py, lines 512–527

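Fetches the Hopfield association matrix used for pooling, computed by passing the specified data through the Hopfield association. Parameters: input, the data to be passed through the Hopfield association; stored_pattern_padding_mask, a mask applied to the stored patterns; association_mask, a mask applied to the inner association matrix. Returns the association matrix as computed by the Hopfield core module.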
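To make the returned matrix concrete, here is a minimal stdlib sketch of a single-head association matrix in the modern Hopfield formulation of Ramsauer et al. (2020): row i is softmax(beta · ξ_i · Yᵀ) taken over the stored patterns Y. It is an illustration under stated assumptions, not the hflayers implementation: the function name `association_matrix` and the `beta` argument are hypothetical, learned projections and batching are omitted, and both masks are assumed to follow the PyTorch `nn.MultiheadAttention` convention in which True marks an entry to ignore.

```python
import math
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def association_matrix(
    state: Sequence[Sequence[float]],
    stored: Sequence[Sequence[float]],
    beta: float = 1.0,
    stored_pattern_padding_mask: Optional[Sequence[bool]] = None,
    association_mask: Optional[Sequence[Sequence[bool]]] = None,
) -> Matrix:
    """Hypothetical sketch: row-wise softmax(beta * state @ stored^T).

    Assumed mask convention: True marks an entry that must be ignored.
    stored_pattern_padding_mask[j] masks stored pattern j for every row;
    association_mask[i][j] masks one (state i, stored j) entry.
    """
    result: Matrix = []
    for i, xi in enumerate(state):
        # Scaled dot-product similarity of state pattern i with each stored pattern.
        scores = [beta * sum(a * b for a, b in zip(xi, y)) for y in stored]
        for j in range(len(stored)):
            padded = stored_pattern_padding_mask is not None and stored_pattern_padding_mask[j]
            blocked = association_mask is not None and association_mask[i][j]
            if padded or blocked:
                scores[j] = -math.inf  # excluded from the softmax
        m = max(scores)
        if m == -math.inf:
            raise ValueError(f"row {i}: every stored pattern is masked")
        # Numerically stable softmax; masked entries contribute exactly 0.
        exps = [0.0 if s == -math.inf else math.exp(s - m) for s in scores]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


# Usage: the third stored pattern is padding, so its column is 0.0.
A = association_matrix(
    [[1.0, 0.0]],
    [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]],
    stored_pattern_padding_mask=[False, False, True],
)
print(A)
```

Masked entries come out exactly zero while each row still sums to one, which is why a padding mask matters when stored sequences of different lengths are batched together.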

get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                       stored_pattern_padding_mask: Optional[Tensor] = None,
                       association_mask: Optional[Tensor] = None) -> Tensor

Source

```python
# Names this excerpt relies on.
import torch
from torch import Tensor
from typing import Optional, Tuple, Union


# Method excerpt from hflayers/__init__.py (class-body indentation removed).
def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                           stored_pattern_padding_mask: Optional[Tensor] = None,
                           association_mask: Optional[Tensor] = None) -> Tensor:
    """
    Fetch Hopfield association matrix used for pooling gathered by passing through the specified data.

    :param input: data to be passed through the Hopfield association
    :param stored_pattern_padding_mask: mask to be applied on stored patterns
    :param association_mask: mask to be applied on inner association matrix
    :return: association matrix as computed by the Hopfield core module
    """
    # Inspection helper: no gradients are tracked for the returned matrix.
    with torch.no_grad():
        # Normalize the input, then delegate to the Hopfield core module.
        return self.hopfield.get_association_matrix(
            input=self._prepare_input(input=input),
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)
```
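The method itself is a thin wrapper: it normalizes `input` with `_prepare_input`, forwards both masks unchanged, and delegates to the core module's `get_association_matrix` without tracking gradients. The stdlib sketch below mirrors that delegation with hypothetical stand-ins; `StubCore`, `PoolingWrapperSketch`, and the `query` attribute are illustrative names, not hflayers API. It assumes, without checking the library, that a single input serves as both stored pattern and pattern projection while a pair supplies them separately, and that the prepared input is a (stored, state, projection) triple. `torch.no_grad()` is left out to keep the sketch dependency-free.

```python
from typing import Any, Optional, Tuple, Union


class StubCore:
    """Hypothetical stand-in for the Hopfield core module: echoes its arguments."""

    def get_association_matrix(self, input: Any,
                               stored_pattern_padding_mask: Optional[Any] = None,
                               association_mask: Optional[Any] = None) -> dict:
        return {"input": input,
                "stored_pattern_padding_mask": stored_pattern_padding_mask,
                "association_mask": association_mask}


class PoolingWrapperSketch:
    """Hypothetical wrapper mirroring the delegation in get_association_matrix."""

    def __init__(self, core: StubCore, query: Any) -> None:
        self.hopfield = core
        self.query = query  # assumed: a learned pooling query used as the state pattern

    def _prepare_input(self, input: Union[Any, Tuple[Any, Any]]) -> Tuple[Any, Any, Any]:
        # Assumption: one value plays both roles; a pair gives (stored, projection).
        if isinstance(input, tuple):
            if len(input) != 2:
                raise ValueError("expected (stored_pattern, pattern_projection)")
            stored, projection = input
        else:
            stored = projection = input
        return stored, self.query, projection

    def get_association_matrix(self, input: Union[Any, Tuple[Any, Any]],
                               stored_pattern_padding_mask: Optional[Any] = None,
                               association_mask: Optional[Any] = None) -> dict:
        # The real method runs this call inside torch.no_grad(); omitted here.
        return self.hopfield.get_association_matrix(
            input=self._prepare_input(input=input),
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)


wrapper = PoolingWrapperSketch(StubCore(), query="q")
print(wrapper.get_association_matrix("X", stored_pattern_padding_mask="pad"))
```

Keeping the masks untouched in the wrapper means their semantics are defined in one place, the core module.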

Callers

nothing calls this directly

Calls 2

_prepare_input · Method · 0.95

Tested by

no test coverage detected