MCPcopy
hub / github.com/ddbourgin/numpy-ml / extract_grads

Method extract_grads

numpy_ml/tests/nn_torch_models.py:1018–1100  ·  view source on GitHub ↗
(self, X)

Source from the content-addressed store, hash-verified

1016 return self.A
1017
1018 def extract_grads(self, X):
1019 self.forward(X)
1020 self.loss = self.A.sum()
1021 self.loss.backward()
1022
1023 # forward
1024 w_ii, w_if, w_ic, w_io = self.layer1.weight_ih_l0.chunk(4, 0)
1025 w_hi, w_hf, w_hc, w_ho = self.layer1.weight_hh_l0.chunk(4, 0)
1026 bu_f, bf_f, bc_f, bo_f = self.layer1.bias_ih_l0.chunk(4, 0)
1027
1028 Wu_f = torch.cat([torch.t(w_hi), torch.t(w_ii)], dim=0)
1029 Wf_f = torch.cat([torch.t(w_hf), torch.t(w_if)], dim=0)
1030 Wc_f = torch.cat([torch.t(w_hc), torch.t(w_ic)], dim=0)
1031 Wo_f = torch.cat([torch.t(w_ho), torch.t(w_io)], dim=0)
1032
1033 dw_ii, dw_if, dw_ic, dw_io = self.layer1.weight_ih_l0.grad.chunk(4, 0)
1034 dw_hi, dw_hf, dw_hc, dw_ho = self.layer1.weight_hh_l0.grad.chunk(4, 0)
1035 dbu_f, dbf_f, dbc_f, dbo_f = self.layer1.bias_ih_l0.grad.chunk(4, 0)
1036
1037 dWu_f = torch.cat([torch.t(dw_hi), torch.t(dw_ii)], dim=0)
1038 dWf_f = torch.cat([torch.t(dw_hf), torch.t(dw_if)], dim=0)
1039 dWc_f = torch.cat([torch.t(dw_hc), torch.t(dw_ic)], dim=0)
1040 dWo_f = torch.cat([torch.t(dw_ho), torch.t(dw_io)], dim=0)
1041
1042 # backward
1043 w_ii, w_if, w_ic, w_io = self.layer1.weight_ih_l0_reverse.chunk(4, 0)
1044 w_hi, w_hf, w_hc, w_ho = self.layer1.weight_hh_l0_reverse.chunk(4, 0)
1045 bu_b, bf_b, bc_b, bo_b = self.layer1.bias_ih_l0_reverse.chunk(4, 0)
1046
1047 Wu_b = torch.cat([torch.t(w_hi), torch.t(w_ii)], dim=0)
1048 Wf_b = torch.cat([torch.t(w_hf), torch.t(w_if)], dim=0)
1049 Wc_b = torch.cat([torch.t(w_hc), torch.t(w_ic)], dim=0)
1050 Wo_b = torch.cat([torch.t(w_ho), torch.t(w_io)], dim=0)
1051
1052 dw_ii, dw_if, dw_ic, dw_io = self.layer1.weight_ih_l0_reverse.grad.chunk(4, 0)
1053 dw_hi, dw_hf, dw_hc, dw_ho = self.layer1.weight_hh_l0_reverse.grad.chunk(4, 0)
1054 dbu_b, dbf_b, dbc_b, dbo_b = self.layer1.bias_ih_l0_reverse.grad.chunk(4, 0)
1055
1056 dWu_b = torch.cat([torch.t(dw_hi), torch.t(dw_ii)], dim=0)
1057 dWf_b = torch.cat([torch.t(dw_hf), torch.t(dw_if)], dim=0)
1058 dWc_b = torch.cat([torch.t(dw_hc), torch.t(dw_ic)], dim=0)
1059 dWo_b = torch.cat([torch.t(dw_ho), torch.t(dw_io)], dim=0)
1060
1061 orig, X_swap = [0, 1, 2], [-1, -3, -2]
1062 grads = {
1063 "X": np.moveaxis(self.X.detach().numpy(), orig, X_swap),
1064 "Wu_f": Wu_f.detach().numpy(),
1065 "Wf_f": Wf_f.detach().numpy(),
1066 "Wc_f": Wc_f.detach().numpy(),
1067 "Wo_f": Wo_f.detach().numpy(),
1068 "bu_f": bu_f.detach().numpy().reshape(-1, 1),
1069 "bf_f": bf_f.detach().numpy().reshape(-1, 1),
1070 "bc_f": bc_f.detach().numpy().reshape(-1, 1),
1071 "bo_f": bo_f.detach().numpy().reshape(-1, 1),
1072 "Wu_b": Wu_b.detach().numpy(),
1073 "Wf_b": Wf_b.detach().numpy(),
1074 "Wc_b": Wc_b.detach().numpy(),
1075 "Wo_b": Wo_b.detach().numpy(),

Callers 1

test_BidirectionalLSTMFunction · 0.95

Calls 2

forward · Method · 0.95
backward · Method · 0.45

Tested by 1

test_BidirectionalLSTMFunction · 0.76