MCPcopy
hub / github.com/msracver/Deformable-ConvNets / bind

Method bind

fpn/core/module.py:802–856  ·  view source on GitHub ↗
(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write')

Source from the content-addressed store, hash-verified

800 self.params_initialized = True
801
def bind(self, data_shapes, label_shapes=None, for_training=True,
         inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
    """Bind a fresh inner ``Module`` using the recorded maximum input shapes.

    For every input name that appears in ``self._max_data_shapes[0]`` or
    ``self._max_label_shapes[0]`` (when those are set), the recorded maximum
    shape replaces the shape given here. Inputs without a recorded maximum keep
    the shape passed in. The same shape lists are then used for every context.
    (Presumably this lets one executor serve all smaller inputs — confirm with
    the callers of ``forward``.)

    Parameters
    ----------
    data_shapes : list
        Per-context list. Only ``data_shapes[0]``, a sequence of
        ``(name, shape)`` pairs, is read.
    label_shapes : list or None
        Per-context list shaped like ``data_shapes``. ``None``, an empty list,
        or a list whose entries are all ``None`` binds without labels.
    for_training : bool
        Stored on ``self`` and forwarded to the inner ``Module.bind``.
    inputs_need_grad : bool
        Stored on ``self`` and forwarded to the inner ``Module.bind``.
    force_rebind : bool
        If true, ``self._reset_bind()`` is called before binding again.
    shared_module : None
        Must be ``None``; sharing is not supported here (AssertionError otherwise).
    grad_req : str
        Forwarded to the inner ``Module.bind``.

    If parameters were already initialized, they are read with ``get_params``
    before rebinding and restored with ``set_params`` afterwards.
    """
    # In case params are already initialized, keep them across the rebind.
    if self.params_initialized:
        arg_params, aux_params = self.get_params()

    # Force rebinding is typically used when switching from the training
    # phase to the prediction phase.
    if force_rebind:
        self._reset_bind()

    if self.binded:
        self.logger.warning('Already binded, ignoring bind()')
        return

    assert shared_module is None, 'shared_module for MutableModule is not supported'

    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad

    # Recorded maximum shape per input name (first entry of each list only).
    max_shapes_dict = dict()
    if self._max_data_shapes is not None:
        max_shapes_dict.update(dict(self._max_data_shapes[0]))
    if self._max_label_shapes is not None:
        max_shapes_dict.update(dict(self._max_label_shapes[0]))

    def _with_max_shapes(shapes):
        """Return (name, shape) pairs, substituting any recorded maximum shape."""
        return [(name, max_shapes_dict.get(name, shape)) for name, shape in shapes]

    max_data_shapes = _with_max_shapes(data_shapes[0])

    # label_shapes defaults to None, so check for that before calling .count().
    has_labels = (label_shapes is not None
                  and label_shapes.count(None) != len(label_shapes))
    max_label_shapes = None
    if has_labels:
        # An empty label list is normalized to None as well.
        max_label_shapes = _with_max_shapes(label_shapes[0]) or None

    num_ctx = len(self._context)
    module = Module(self._symbol, self._data_names, self._label_names, logger=self.logger,
                    context=self._context, work_load_list=self._work_load_list,
                    fixed_param_names=self._fixed_param_names)
    module.bind([max_data_shapes for _ in range(num_ctx)],
                [max_label_shapes for _ in range(num_ctx)],
                for_training, inputs_need_grad, force_rebind=False, shared_module=None,
                grad_req=grad_req)
    self._curr_module = module
    # Mark as bound only after the inner module bound successfully, so a failed
    # bind does not leave this object claiming to be bound. This must happen
    # before set_params, which operates on a bound module.
    self.binded = True

    # Copy back saved params, if already initialized.
    if self.params_initialized:
        self.set_params(arg_params, aux_params)
857
858 def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
859 """Save current progress to checkpoint.

Callers 1

fit (method) · 0.95

Calls 6

get_params (method) · 0.95
_reset_bind (method) · 0.95
bind (method) · 0.95
Module (class) · 0.70
update (method) · 0.45
set_params (method) · 0.45

Tested by

No test coverage detected.