| 919 | |
| 920 | class TorchBidirectionalLSTM(nn.Module): |
| 921 | def __init__(self, n_in, n_out, params, **kwargs): |
| 922 | super(TorchBidirectionalLSTM, self).__init__() |
| 923 | |
| 924 | self.layer1 = nn.LSTM( |
| 925 | input_size=n_in, |
| 926 | hidden_size=n_out, |
| 927 | num_layers=1, |
| 928 | bidirectional=True, |
| 929 | bias=True, |
| 930 | ) |
| 931 | |
| 932 | Wiu = params["components"]["cell_fwd"]["Wu"][n_out:, :].T |
| 933 | Wif = params["components"]["cell_fwd"]["Wf"][n_out:, :].T |
| 934 | Wic = params["components"]["cell_fwd"]["Wc"][n_out:, :].T |
| 935 | Wio = params["components"]["cell_fwd"]["Wo"][n_out:, :].T |
| 936 | W_ih_f = np.vstack([Wiu, Wif, Wic, Wio]) |
| 937 | |
| 938 | Whu = params["components"]["cell_fwd"]["Wu"][:n_out, :].T |
| 939 | Whf = params["components"]["cell_fwd"]["Wf"][:n_out, :].T |
| 940 | Whc = params["components"]["cell_fwd"]["Wc"][:n_out, :].T |
| 941 | Who = params["components"]["cell_fwd"]["Wo"][:n_out, :].T |
| 942 | W_hh_f = np.vstack([Whu, Whf, Whc, Who]) |
| 943 | |
| 944 | assert self.layer1.weight_ih_l0.shape == W_ih_f.shape |
| 945 | assert self.layer1.weight_hh_l0.shape == W_hh_f.shape |
| 946 | |
| 947 | self.layer1.weight_ih_l0 = nn.Parameter(torch.FloatTensor(W_ih_f)) |
| 948 | self.layer1.weight_hh_l0 = nn.Parameter(torch.FloatTensor(W_hh_f)) |
| 949 | |
| 950 | Wiu = params["components"]["cell_bwd"]["Wu"][n_out:, :].T |
| 951 | Wif = params["components"]["cell_bwd"]["Wf"][n_out:, :].T |
| 952 | Wic = params["components"]["cell_bwd"]["Wc"][n_out:, :].T |
| 953 | Wio = params["components"]["cell_bwd"]["Wo"][n_out:, :].T |
| 954 | W_ih_b = np.vstack([Wiu, Wif, Wic, Wio]) |
| 955 | |
| 956 | Whu = params["components"]["cell_bwd"]["Wu"][:n_out, :].T |
| 957 | Whf = params["components"]["cell_bwd"]["Wf"][:n_out, :].T |
| 958 | Whc = params["components"]["cell_bwd"]["Wc"][:n_out, :].T |
| 959 | Who = params["components"]["cell_bwd"]["Wo"][:n_out, :].T |
| 960 | W_hh_b = np.vstack([Whu, Whf, Whc, Who]) |
| 961 | |
| 962 | assert self.layer1.weight_ih_l0_reverse.shape == W_ih_b.shape |
| 963 | assert self.layer1.weight_hh_l0_reverse.shape == W_hh_b.shape |
| 964 | |
| 965 | self.layer1.weight_ih_l0_reverse = nn.Parameter(torch.FloatTensor(W_ih_b)) |
| 966 | self.layer1.weight_hh_l0_reverse = nn.Parameter(torch.FloatTensor(W_hh_b)) |
| 967 | |
| 968 | b_f = np.concatenate( |
| 969 | [ |
| 970 | params["components"]["cell_fwd"]["bu"], |
| 971 | params["components"]["cell_fwd"]["bf"], |
| 972 | params["components"]["cell_fwd"]["bc"], |
| 973 | params["components"]["cell_fwd"]["bo"], |
| 974 | ], |
| 975 | axis=-1, |
| 976 | ).flatten() |
| 977 | |
| 978 | assert self.layer1.bias_ih_l0.shape == b_f.shape |