MCPcopy
hub / github.com/ddbourgin/numpy-ml / __init__

Method __init__

numpy_ml/tests/nn_torch_models.py:921–998  ·  view source on GitHub ↗
(self, n_in, n_out, params, **kwargs)

Source from the content-addressed store, hash-verified

919
920class TorchBidirectionalLSTM(nn.Module):
921 def __init__(self, n_in, n_out, params, **kwargs):
922 super(TorchBidirectionalLSTM, self).__init__()
923
924 self.layer1 = nn.LSTM(
925 input_size=n_in,
926 hidden_size=n_out,
927 num_layers=1,
928 bidirectional=True,
929 bias=True,
930 )
931
932 Wiu = params["components"]["cell_fwd"]["Wu"][n_out:, :].T
933 Wif = params["components"]["cell_fwd"]["Wf"][n_out:, :].T
934 Wic = params["components"]["cell_fwd"]["Wc"][n_out:, :].T
935 Wio = params["components"]["cell_fwd"]["Wo"][n_out:, :].T
936 W_ih_f = np.vstack([Wiu, Wif, Wic, Wio])
937
938 Whu = params["components"]["cell_fwd"]["Wu"][:n_out, :].T
939 Whf = params["components"]["cell_fwd"]["Wf"][:n_out, :].T
940 Whc = params["components"]["cell_fwd"]["Wc"][:n_out, :].T
941 Who = params["components"]["cell_fwd"]["Wo"][:n_out, :].T
942 W_hh_f = np.vstack([Whu, Whf, Whc, Who])
943
944 assert self.layer1.weight_ih_l0.shape == W_ih_f.shape
945 assert self.layer1.weight_hh_l0.shape == W_hh_f.shape
946
947 self.layer1.weight_ih_l0 = nn.Parameter(torch.FloatTensor(W_ih_f))
948 self.layer1.weight_hh_l0 = nn.Parameter(torch.FloatTensor(W_hh_f))
949
950 Wiu = params["components"]["cell_bwd"]["Wu"][n_out:, :].T
951 Wif = params["components"]["cell_bwd"]["Wf"][n_out:, :].T
952 Wic = params["components"]["cell_bwd"]["Wc"][n_out:, :].T
953 Wio = params["components"]["cell_bwd"]["Wo"][n_out:, :].T
954 W_ih_b = np.vstack([Wiu, Wif, Wic, Wio])
955
956 Whu = params["components"]["cell_bwd"]["Wu"][:n_out, :].T
957 Whf = params["components"]["cell_bwd"]["Wf"][:n_out, :].T
958 Whc = params["components"]["cell_bwd"]["Wc"][:n_out, :].T
959 Who = params["components"]["cell_bwd"]["Wo"][:n_out, :].T
960 W_hh_b = np.vstack([Whu, Whf, Whc, Who])
961
962 assert self.layer1.weight_ih_l0_reverse.shape == W_ih_b.shape
963 assert self.layer1.weight_hh_l0_reverse.shape == W_hh_b.shape
964
965 self.layer1.weight_ih_l0_reverse = nn.Parameter(torch.FloatTensor(W_ih_b))
966 self.layer1.weight_hh_l0_reverse = nn.Parameter(torch.FloatTensor(W_hh_b))
967
968 b_f = np.concatenate(
969 [
970 params["components"]["cell_fwd"]["bu"],
971 params["components"]["cell_fwd"]["bf"],
972 params["components"]["cell_fwd"]["bc"],
973 params["components"]["cell_fwd"]["bo"],
974 ],
975 axis=-1,
976 ).flatten()
977
978 assert self.layer1.bias_ih_l0.shape == b_f.shape

Callers

nothing calls this directly

Calls 1

__init__ (Method) · 0.45

Tested by

no test coverage detected