MCPcopy
hub / github.com/ml-jku/hopfield-layers / _associate

Method _associate

hflayers/__init__.py:173–219  ·  view source on GitHub ↗

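Apply Hopfield association module on specified data. :param data: data to be processed by Hopfield core module :param return_raw_associations: return raw association (softmax) values, unmodified :param return_projected_patterns: return pattern projection values, unmodified :param stored_pattern_padding_mask: mask to be applied on stored patterns :param association_mask: mask to be applied on inner association matrix :return: Hopfield-processed input data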

(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
                   return_raw_associations: bool = False, return_projected_patterns: bool = False,
                   stored_pattern_padding_mask: Optional[Tensor] = None,
                   association_mask: Optional[Tensor] = None)

Source from the content-addressed store, hash-verified

171 return transposed_result[0] if len(transposed_result) == 1 else transposed_result
172
def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
               return_raw_associations: bool = False, return_projected_patterns: bool = False,
               stored_pattern_padding_mask: Optional[Tensor] = None,
               association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]:
    """
    Apply Hopfield association module on specified data.

    :param data: data to be processed by Hopfield core module; either a single tensor, which is
        used as stored pattern, state pattern and pattern projection alike, or a tuple of three
        tensors given in the order (stored pattern, state pattern, pattern projection)
    :param return_raw_associations: return raw association (softmax) values, unmodified
    :param return_projected_patterns: return pattern projection values, unmodified
    :param stored_pattern_padding_mask: mask to be applied on stored patterns
    :param association_mask: mask to be applied on inner association matrix
    :return: Hopfield-processed input data, as returned by ``self.association_core``
    """
    # isinstance (instead of an exact type comparison) also admits Tensor subclasses such as
    # torch.nn.Parameter and tuple subclasses such as named tuples, which were rejected before.
    is_single_tensor = isinstance(data, Tensor)
    assert is_single_tensor or (isinstance(data, tuple) and (len(data) == 3)), \
        r'either one tensor to be used as "stored pattern", "state pattern" and' \
        r' "pattern_projection" must be provided, or three separate ones.'
    if is_single_tensor:
        stored_pattern, state_pattern, pattern_projection = data, data, data
    else:
        stored_pattern, state_pattern, pattern_projection = data

    # Optionally transpose data.
    stored_pattern, state_pattern, pattern_projection = self._maybe_transpose(
        stored_pattern, state_pattern, pattern_projection)

    def _maybe_normalize(norm, pattern: Tensor) -> Tensor:
        # Apply the optional normalization module to `pattern`. All leading dimensions are
        # flattened so the module receives a 2-D (rows, last_dim) input, then the original shape
        # is restored. Assumes the feature dimension is last after `_maybe_transpose` — TODO
        # confirm against `_maybe_transpose`; for 3-D inputs this matches the former `shape[2]`.
        if norm is None:
            return pattern
        flattened = pattern.reshape(shape=(-1, pattern.shape[-1]))
        return norm(input=flattened).reshape(shape=pattern.shape)

    # Optionally apply stored pattern, state pattern and pattern projection normalization.
    stored_pattern = _maybe_normalize(self.norm_stored_pattern, stored_pattern)
    state_pattern = _maybe_normalize(self.norm_state_pattern, state_pattern)
    pattern_projection = _maybe_normalize(self.norm_pattern_projection, pattern_projection)

    # Apply Hopfield association and optional activation function. The state pattern acts as
    # query, the stored pattern as key and the pattern projection as value.
    return self.association_core(
        query=state_pattern, key=stored_pattern, value=pattern_projection,
        key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask,
        scaling=self.__scaling, update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps,
        return_raw_associations=return_raw_associations, return_pattern_projections=return_projected_patterns)
220
221 def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
222 stored_pattern_padding_mask: Optional[Tensor] = None,

Callers 3

forward — Method · 0.95

Calls 1

_maybe_transpose — Method · 0.95

Tested by

no test coverage detected