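Apply Hopfield association module on specified data. :param data: data to be processed by Hopfield core module :param return_raw_associations: return raw association (softmax) values, unmodified :param return_projected_patterns: return pattern projection values, unmodified :param stored_pattern_padding_mask: mask to be applied on stored patterns :param association_mask: mask to be applied on inner association matrix :return: Hopfield-processed input data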
(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
return_raw_associations: bool = False, return_projected_patterns: bool = False,
stored_pattern_padding_mask: Optional[Tensor] = None,
association_mask: Optional[Tensor] = None)
| 171 | return transposed_result[0] if len(transposed_result) == 1 else transposed_result |
| 172 | |
| 173 | def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], |
| 174 | return_raw_associations: bool = False, return_projected_patterns: bool = False, |
| 175 | stored_pattern_padding_mask: Optional[Tensor] = None, |
| 176 | association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]: |
| 177 | """ |
| 178 | Apply Hopfield association module on specified data. |
| 179 | |
| 180 | :param data: data to be processed by Hopfield core module |
| 181 | :param return_raw_associations: return raw association (softmax) values, unmodified |
| 182 | :param return_projected_patterns: return pattern projection values, unmodified |
| 183 | :param stored_pattern_padding_mask: mask to be applied on stored patterns |
| 184 | :param association_mask: mask to be applied on inner association matrix |
| 185 | :return: Hopfield-processed input data |
| 186 | """ |
| 187 | assert (type(data) == Tensor) or ((type(data) == tuple) and (len(data) == 3)), \ |
| 188 | r'either one tensor to be used as "stored pattern", "state pattern" and' \ |
| 189 | r' "pattern_projection" must be provided, or three separate ones.' |
| 190 | if type(data) == Tensor: |
| 191 | stored_pattern, state_pattern, pattern_projection = data, data, data |
| 192 | else: |
| 193 | stored_pattern, state_pattern, pattern_projection = data |
| 194 | |
| 195 | # Optionally transpose data. |
| 196 | stored_pattern, state_pattern, pattern_projection = self._maybe_transpose( |
| 197 | stored_pattern, state_pattern, pattern_projection) |
| 198 | |
| 199 | # Optionally apply stored pattern normalization. |
| 200 | if self.norm_stored_pattern is not None: |
| 201 | stored_pattern = self.norm_stored_pattern(input=stored_pattern.reshape( |
| 202 | shape=(-1, stored_pattern.shape[2]))).reshape(shape=stored_pattern.shape) |
| 203 | |
| 204 | # Optionally apply state pattern normalization. |
| 205 | if self.norm_state_pattern is not None: |
| 206 | state_pattern = self.norm_state_pattern(input=state_pattern.reshape( |
| 207 | shape=(-1, state_pattern.shape[2]))).reshape(shape=state_pattern.shape) |
| 208 | |
| 209 | # Optionally apply pattern projection normalization. |
| 210 | if self.norm_pattern_projection is not None: |
| 211 | pattern_projection = self.norm_pattern_projection(input=pattern_projection.reshape( |
| 212 | shape=(-1, pattern_projection.shape[2]))).reshape(shape=pattern_projection.shape) |
| 213 | |
| 214 | # Apply Hopfield association and optional activation function. |
| 215 | return self.association_core( |
| 216 | query=state_pattern, key=stored_pattern, value=pattern_projection, |
| 217 | key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask, |
| 218 | scaling=self.__scaling, update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps, |
| 219 | return_raw_associations=return_raw_associations, return_pattern_projections=return_projected_patterns) |
| 220 | |
| 221 | def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor, Tensor]], |
| 222 | stored_pattern_padding_mask: Optional[Tensor] = None, |
no test coverage detected