| 205 | self.to(device) |
| 206 | |
def forward(self, inputs):
    """Build explicit field-interaction features, one layer at a time.

    At every layer the previous hidden map is combined with the original
    input by an outer product over the field axes, computed separately for
    each embedding coordinate. The two field axes are then flattened into a
    single channel axis. The result is passed through ``self.conv1ds[idx]``
    and, optionally, ``self.activation``.

    Part (or all) of each layer's output is collected as a "direct"
    connection.

    Args:
        inputs: a 3-D tensor. The einsum below reads it as
            (batch, fields, embedding_dim).

    Returns:
        Tensor of shape (batch, C), where C is the total number of channels
        across all direct connections. The direct outputs are concatenated
        on axis 1 and then summed over the embedding axis.

    Raises:
        ValueError: if ``inputs`` does not have exactly 3 dimensions.
    """
    rank = len(inputs.shape)
    if rank != 3:
        raise ValueError(
            "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (rank))
    batch, _, emb_dim = inputs.shape

    last_idx = len(self.layer_size) - 1
    prev = inputs          # x^(k-1); starts as x^0
    direct_parts = []      # tensors concatenated into the final output

    for idx, size in enumerate(self.layer_size):
        # x^(k-1) (x) x^0 per embedding coordinate: (batch, h_prev, m, emb_dim).
        pairwise = torch.einsum('bhd,bmd->bhmd', prev, inputs)
        # Merge both field axes into one channel axis: (batch, h_prev * m, emb_dim).
        pairwise = pairwise.reshape(batch, prev.shape[1] * inputs.shape[1], emb_dim)
        # NOTE(review): conv1ds[idx] is built in __init__ (not visible);
        # presumably it maps h_prev * m channels down to `size` channels.
        feature_map = self.conv1ds[idx](pairwise)

        # Either None or the string 'linear' means "no activation".
        is_linear = self.activation is None or self.activation == 'linear'
        activated = feature_map if is_linear else self.activation(feature_map)

        if not self.split_half:
            # The whole map feeds the next layer AND the output.
            prev = activated
            direct_parts.append(activated)
        elif idx == last_idx:
            # Last layer: nothing follows, so everything goes to the output.
            direct_parts.append(activated)
        else:
            # First half of the channels feeds the next layer; second half goes out.
            prev, to_output = torch.split(activated, [size // 2] * 2, 1)
            direct_parts.append(to_output)

    return torch.cat(direct_parts, dim=1).sum(dim=-1)
| 249 | |
| 250 | |
| 251 | class AFMLayer(nn.Module): |