(self, X)
| 998 | self.layer1.bias_hh_l0_reverse = nn.Parameter(torch.FloatTensor(b_b)) |
| 999 | |
def forward(self, X):
    """
    Run ``self.layer1`` over a batch of sequences and cache the result.

    Parameters
    ----------
    X : array-like of shape (batch, input_size, seq_len)
        Input batch. The axes are rearranged with ``np.moveaxis`` before the
        tensor check below, so a NumPy array is presumably expected here --
        TODO confirm whether tensor inputs also need to be supported.

    Returns
    -------
    A : torch.Tensor
        The first value returned by ``self.layer1``. It is also stored on
        ``self.A`` with ``retain_grad()`` enabled; the rearranged input is
        stored on ``self.X`` in the same way.

    Raises
    ------
    ValueError
        If the rearranged input is not 3-dimensional.
    """
    # (batch, input_size, seq_len) -> (seq_len, batch, input_size)
    self.X = np.moveaxis(X, [0, 1, 2], [-2, -1, -3])

    if not isinstance(self.X, torch.Tensor):
        self.X = torchify(self.X)

    # keep the gradient w.r.t. the input readable after a backward pass
    self.X.retain_grad()

    # Explicit rank check. The previous code only enforced this implicitly,
    # by unpacking `self.X.shape` into three otherwise-unused locals.
    n_dims = len(self.X.shape)
    if n_dims != 3:
        raise ValueError(
            "Expected a 3-dimensional input of shape "
            "(batch, input_size, seq_len), but got {} dimensions".format(n_dims)
        )

    # forward pass; no initial state is supplied (initial hidden state is 0).
    # The second return value (final hidden/cell states) is not needed here.
    self.A, _ = self.layer1(self.X)
    self.A.retain_grad()
    return self.A
| 1017 | |
| 1018 | def extract_grads(self, X): |
| 1019 | self.forward(X) |
no test coverage detected