| 985 | |
| 986 | |
| 987 | class BidirectionalLSTM(ModuleBase): |
def __init__(
    self,
    n_out,
    act_fn=None,
    gate_fn=None,
    merge_mode="concat",
    init="glorot_uniform",
    optimizer=None,
):
    """
    A single bidirectional long short-term memory (LSTM) layer.

    Parameters
    ----------
    n_out : int
        The dimension of a single hidden state / output on a given
        timestep.
    act_fn : :doc:`Activation <numpy_ml.neural_nets.activations>` object or None
        The activation function for computing ``A[t]``. If not specified,
        use :class:`~numpy_ml.neural_nets.activations.Tanh` by default.
    gate_fn : :doc:`Activation <numpy_ml.neural_nets.activations>` object or None
        The gate function for computing the update, forget, and output
        gates. If not specified, use
        :class:`~numpy_ml.neural_nets.activations.Sigmoid` by default.
    merge_mode : {"sum", "multiply", "concat", "average"}
        Mode by which outputs of the forward and backward LSTMs will be
        combined. Default is 'concat'.
    init : {'glorot_normal', 'glorot_uniform', 'he_normal', 'he_uniform'}
        The weight initialization strategy. Default is 'glorot_uniform'.
    optimizer : str or :doc:`Optimizer <numpy_ml.neural_nets.optimizers>` object or None
        The optimization strategy to use when performing gradient updates
        within the `update` method. If None, use the
        :class:`~numpy_ml.neural_nets.optimizers.SGD` optimizer with
        default parameters. Default is None.

    Raises
    ------
    ValueError
        If `merge_mode` is not one of "sum", "multiply", "concat", or
        "average".
    """
    super().__init__()

    # Fail fast on an unsupported merge mode. Otherwise the bad value
    # would simply be stored here and only noticed (if at all) wherever
    # the forward and backward outputs are later combined.
    valid_merge_modes = ("sum", "multiply", "concat", "average")
    if merge_mode not in valid_merge_modes:
        raise ValueError(
            "merge_mode must be one of {}, got {!r}".format(
                valid_merge_modes, merge_mode
            )
        )

    self.init = init
    # Input dimension is left unset at construction time
    # (presumably inferred later from the data -- confirm).
    self.n_in = None
    self.n_out = n_out
    self.optimizer = optimizer
    self.merge_mode = merge_mode
    # Fall back to Tanh / Sigmoid when no activation / gate fn is given.
    self.act_fn = Tanh() if act_fn is None else act_fn
    self.gate_fn = Sigmoid() if gate_fn is None else gate_fn
    # Build the forward (cell_fwd) and backward (cell_bwd) LSTM cells
    # from the hyperparameters stored above.
    self._init_params()
| 1032 | |
| 1033 | def _init_params(self): |
| 1034 | self.cell_fwd = LSTMCell( |
| 1035 | init=self.init, |
| 1036 | n_out=self.n_out, |
| 1037 | act_fn=self.act_fn, |
| 1038 | gate_fn=self.gate_fn, |
| 1039 | optimizer=self.optimizer, |
| 1040 | ) |
| 1041 | self.cell_bwd = LSTMCell( |
| 1042 | init=self.init, |
| 1043 | n_out=self.n_out, |
| 1044 | act_fn=self.act_fn, |
no outgoing calls