(self, data_shapes, label_shapes=None, for_training=True,
inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write')
| 800 | self.params_initialized = True |
| 801 | |
def bind(self, data_shapes, label_shapes=None, for_training=True,
         inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
    """Bind a fresh inner ``Module`` using the configured maximum shapes.

    For every ``(name, shape)`` pair in the first per-context entry of
    ``data_shapes`` / ``label_shapes``, the shape is replaced by the one
    recorded for that name in ``self._max_data_shapes[0]`` /
    ``self._max_label_shapes[0]`` (when present). The resulting shape list is
    replicated once per context in ``self._context``, and a new ``Module`` is
    bound with it and stored in ``self._curr_module``. Previously initialized
    parameters are read before rebinding and restored afterwards.

    Parameters
    ----------
    data_shapes : list
        Per-context list of ``(name, shape)`` lists. Only ``data_shapes[0]``
        is read.
    label_shapes : list or None
        Per-context list of ``(name, shape)`` lists. ``None``, or a list whose
        entries are all ``None``, means "no labels". Otherwise only
        ``label_shapes[0]`` is read.
    for_training : bool
        Forwarded to the inner ``Module.bind``.
    inputs_need_grad : bool
        Forwarded to the inner ``Module.bind``.
    force_rebind : bool
        If True, reset any existing binding first (typically used when
        switching from training to prediction).
    shared_module : None
        Must be ``None``; sharing is not supported for this module.
    grad_req : str
        Forwarded to the inner ``Module.bind``.

    Raises
    ------
    AssertionError
        If ``shared_module`` is not ``None``.
    """
    # In case params were already initialized, keep a copy so they survive
    # the rebind below.
    saved_params = self.get_params() if self.params_initialized else None

    # Force rebinding is typically used when one wants to switch from the
    # training to the prediction phase.
    if force_rebind:
        self._reset_bind()

    if self.binded:
        self.logger.warning('Already binded, ignoring bind()')
        return

    assert shared_module is None, 'shared_module for MutableModule is not supported'

    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad

    # name -> configured maximum shape (label entries override data entries
    # with the same name, as in the original update order).
    max_shapes_dict = dict()
    if self._max_data_shapes is not None:
        max_shapes_dict.update(dict(self._max_data_shapes[0]))
    if self._max_label_shapes is not None:
        max_shapes_dict.update(dict(self._max_label_shapes[0]))

    # Presumably binding at the maximum shape lets later, smaller inputs reuse
    # the same executor -- confirm against the forward() implementation.
    max_data_shapes = [(name, max_shapes_dict.get(name, shape))
                       for name, shape in data_shapes[0]]

    # label_shapes defaults to None; guard it before inspecting its entries
    # (previously `label_shapes.count(None)` raised AttributeError here).
    max_label_shapes = None
    if label_shapes is not None and any(s is not None for s in label_shapes):
        # An empty label list is normalized to None, matching the old behavior.
        max_label_shapes = [(name, max_shapes_dict.get(name, shape))
                            for name, shape in label_shapes[0]] or None

    num_ctx = len(self._context)
    module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                    context=self._context, work_load_list=self._work_load_list,
                    fixed_param_names=self._fixed_param_names)
    module.bind([max_data_shapes for _ in range(num_ctx)],
                [max_label_shapes for _ in range(num_ctx)],
                for_training, inputs_need_grad, force_rebind=False,
                shared_module=None, grad_req=grad_req)
    self._curr_module = module

    # Mark as bound only once the inner module is bound, so a failure above
    # does not leave this object claiming to be bound with no usable module.
    # This must happen before set_params, which may require the bound state.
    self.binded = True

    # Copy back saved params, if already initialized.
    if saved_params is not None:
        arg_params, aux_params = saved_params
        self.set_params(arg_params, aux_params)
| 858 | def save_checkpoint(self, prefix, epoch, save_optimizer_states=False): |
| 859 | """Save current progress to checkpoint. |
no test coverage detected