(self, symbol, data_names=('data',), label_names=('softmax_label',),
logger=logging, context=ctx.cpu(), work_load_list=None,
fixed_param_names=None, state_names=None)
| 57 | Instead they are initialized to 0 and can be set by set_states() |
| 58 | """ |
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             logger=logging, context=ctx.cpu(), work_load_list=None,
             fixed_param_names=None, state_names=None):
    """Record the static configuration of the module derived from ``symbol``.

    No memory is bound and no parameters are created here. Every
    executor/optimizer/parameter attribute starts out as ``None``.

    Parameters
    ----------
    symbol
        Object exposing ``list_arguments()``, ``list_auxiliary_states()``
        and ``list_outputs()``. Its arguments that are not data, label or
        state inputs become the module's parameters.
    data_names : iterable of str or None
        Names of the data inputs. ``None`` means no data inputs.
    label_names : iterable of str or None
        Names of the label inputs. ``None`` means no label inputs. They are
        checked with the final flag of ``_check_input_names`` set to
        ``False``, unlike the other name lists (presumably a non-fatal
        check -- confirm against ``_check_input_names``).
    logger
        Logger passed to the base-class constructor.
    context
        A single ``ctx.Context`` or an iterable of contexts. It is stored as
        a private list.
    work_load_list : sequence or None
        One entry per context. ``None`` means an equal weight of 1 for
        every context. It is stored as a private list.
    fixed_param_names : iterable of str or None
        Names of arguments recorded as fixed parameters.
    state_names : iterable of str or None
        Names of state inputs. Like data and labels, they are excluded from
        the parameter list.

    Raises
    ------
    AssertionError
        If ``work_load_list`` does not have one entry per context.
    """
    super(Module, self).__init__(logger=logger)

    # Accept a single Context or any iterable of contexts. Copy into a fresh
    # list so that later mutation of the caller's list cannot change this
    # module's device placement behind its back.
    if isinstance(context, ctx.Context):
        context = [context]
    self._context = list(context)

    if work_load_list is None:
        work_load_list = [1] * len(self._context)
    work_load_list = list(work_load_list)
    # Raise explicitly rather than with `assert`, so the check is not stripped
    # under `python -O`. AssertionError is kept for backward compatibility
    # with the original failure type.
    if len(work_load_list) != len(self._context):
        raise AssertionError(
            "work_load_list has %d entries but there are %d contexts"
            % (len(work_load_list), len(self._context)))
    self._work_load_list = work_load_list

    self._symbol = symbol

    # Normalize every optional name collection to a fresh list.
    data_names = list(data_names) if data_names is not None else []
    label_names = list(label_names) if label_names is not None else []
    state_names = list(state_names) if state_names is not None else []
    fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else []

    _check_input_names(symbol, data_names, "data", True)
    _check_input_names(symbol, label_names, "label", False)
    _check_input_names(symbol, state_names, "state", True)
    _check_input_names(symbol, fixed_param_names, "fixed_param", True)

    # Parameters are the symbol's arguments that are not fed as inputs.
    # A set gives O(1) membership tests; the order of arg_names is preserved.
    input_names = set(data_names + label_names + state_names)
    self._param_names = [x for x in symbol.list_arguments() if x not in input_names]
    self._fixed_param_names = fixed_param_names
    self._aux_names = symbol.list_auxiliary_states()
    self._data_names = data_names
    self._label_names = label_names
    self._state_names = state_names
    self._output_names = symbol.list_outputs()

    # Parameter storage: not assigned in this constructor.
    self._arg_params = None
    self._aux_params = None
    self._params_dirty = False

    # Optimizer / kvstore state: not assigned in this constructor.
    self._optimizer = None
    self._kvstore = None
    self._update_on_kvstore = None
    self._updater = None
    self._preload_opt_states = None
    self._grad_req = None

    # Executor and input-shape state: not assigned in this constructor.
    self._exec_group = None
    self._data_shapes = None
    self._label_shapes = None
| 108 | |
| 109 | @staticmethod |
| 110 | def load(prefix, epoch, load_optimizer_states=False, **kwargs): |
no test coverage detected