| 1022 | |
| 1023 | |
def forward(self, data_batch, is_train=None):
    """Run a forward pass, rebinding the inner module when input shapes change.

    The shapes ``self._curr_module`` was bound with (one dict per device) are
    compared with the shapes advertised by ``data_batch``. If the number of
    devices differs, or any shape the module was bound with differs, a new
    ``Module`` is created on the first ``len(data_batch.provide_data)``
    contexts. It is bound with ``shared_module=self._curr_module`` and
    replaces the current module. The batch is then forwarded to the
    (possibly new) current module.

    Parameters
    ----------
    data_batch : DataBatch-like
        Must expose ``provide_data`` and ``provide_label``. The per-device
        indexing here assumes each is a list with one entry per device, each
        entry a list of ``(name, shape)`` pairs (TODO confirm against the data
        iterators). ``provide_label`` itself, or any of its entries, may be
        ``None`` when the batch carries no labels.
    is_train : bool or None
        Passed through unchanged to the inner module's ``forward``.
    """
    assert self.binded and self.params_initialized

    curr = self._curr_module

    # Shapes the current module was bound with, one dict per device. Iterate
    # the module's own per-device lists rather than ``self._context``: after a
    # rebind for a short batch the module may span fewer devices than
    # ``self._context`` lists.
    if curr.label_shapes is not None:
        current_shapes = [dict(data + label)
                          for data, label in zip(curr.data_shapes, curr.label_shapes)]
    else:
        current_shapes = [dict(data) for data in curr.data_shapes]

    # Shapes advertised by the batch, one dict per device actually present.
    # Label shapes are merged whenever the batch provides them, regardless of
    # ``is_train``. The comparison below only looks at the module's own keys,
    # so extra label keys are harmless. A module bound with labels can
    # therefore still be checked when ``is_train`` is None or False.
    provided_labels = data_batch.provide_label or []
    input_shapes = []
    for i, data in enumerate(data_batch.provide_data):
        shapes = dict(data)
        if i < len(provided_labels) and provided_labels[i]:
            shapes.update(provided_labels[i])
        input_shapes.append(shapes)

    # A change in device count, or in any shape the module was bound with,
    # requires a rebind. Stop at the first difference.
    shape_changed = len(current_shapes) != len(input_shapes) or any(
        shape != cur[name]
        for pre, cur in zip(current_shapes, input_shapes)
        for name, shape in pre.items()
    )

    if shape_changed:
        # Build a fresh module for the new shapes, on only as many contexts as
        # the batch has devices, bound with ``shared_module=curr``.
        num_devices = len(data_batch.provide_data)
        module = Module(self._symbol, self._data_names, self._label_names,
                        logger=self.logger,
                        context=[self._context[i] for i in range(num_devices)],
                        work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind(data_batch.provide_data, data_batch.provide_label,
                    curr.for_training, curr.inputs_need_grad,
                    force_rebind=False, shared_module=curr)
        self._curr_module = module

    self._curr_module.forward(data_batch, is_train=is_train)
| 1059 | def backward(self, out_grads=None): |
| 1060 | assert self.binded and self.params_initialized |