Module is a basic module that wraps a `Symbol`. It is functionally the same as the `FeedForward` model, except under the module API. Parameters ---------- symbol : Symbol data_names : list of str Default is `('data',)` for a typical model used in image classification.
| 33 | |
| 34 | |
| 35 | class Module(BaseModule): |
| 36 | """Module is a basic module that wraps a `Symbol`. It is functionally the same |
| 37 | as the `FeedForward` model, except under the module API. |
| 38 | |
| 39 | Parameters |
| 40 | ---------- |
| 41 | symbol : Symbol |
| 42 | data_names : list of str |
| 43 | Default is `('data',)` for a typical model used in image classification. |
| 44 | label_names : list of str |
| 45 | Default is `('softmax_label',)` for a typical model used in image |
| 46 | classification. |
| 47 | logger : Logger |
| 48 | Default is `logging`. |
| 49 | context : Context or list of Context |
| 50 | Default is `cpu()`. |
| 51 | work_load_list : list of number |
| 52 | Default `None`, indicating uniform workload. |
| 53 | fixed_param_names : list of str |
| 54 | Default `None`, indicating no network parameters are fixed. |
| 55 | state_names : list of str |
| 56 | States are similar to data and label, but are not provided by the data iterator. |
| 57 | Instead, they are initialized to 0 and can be set by `set_states()`. |
| 58 | """ |
| 59 | def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), |
| 60 | logger=logging, context=ctx.cpu(), work_load_list=None, |
| 61 | fixed_param_names=None, state_names=None): |
| 62 | super(Module, self).__init__(logger=logger) |
| 63 | |
| 64 | if isinstance(context, ctx.Context): |
| 65 | context = [context] |
| 66 | self._context = context |
| 67 | if work_load_list is None: |
| 68 | work_load_list = [1] * len(self._context) |
| 69 | assert len(work_load_list) == len(self._context) |
| 70 | self._work_load_list = work_load_list |
| 71 | |
| 72 | self._symbol = symbol |
| 73 | |
| 74 | data_names = list(data_names) if data_names is not None else [] |
| 75 | label_names = list(label_names) if label_names is not None else [] |
| 76 | state_names = list(state_names) if state_names is not None else [] |
| 77 | fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else [] |
| 78 | |
| 79 | _check_input_names(symbol, data_names, "data", True) |
| 80 | _check_input_names(symbol, label_names, "label", False) |
| 81 | _check_input_names(symbol, state_names, "state", True) |
| 82 | _check_input_names(symbol, fixed_param_names, "fixed_param", True) |
| 83 | |
| 84 | arg_names = symbol.list_arguments() |
| 85 | input_names = data_names + label_names + state_names |
| 86 | self._param_names = [x for x in arg_names if x not in input_names] |
| 87 | self._fixed_param_names = fixed_param_names |
| 88 | self._aux_names = symbol.list_auxiliary_states() |
| 89 | self._data_names = data_names |
| 90 | self._label_names = label_names |
| 91 | self._state_names = state_names |
| 92 | self._output_names = symbol.list_outputs() |