MCPcopy
hub / github.com/msracver/Deformable-ConvNets / forward

Method forward

fpn/core/module.py:1024–1057  ·  view source on GitHub ↗
(self, data_batch, is_train=None)

Source from the content-addressed store, hash-verified

1022
1023
def forward(self, data_batch, is_train=None):
    """Run a forward pass, rebinding the inner module first if input shapes changed.

    The per-device input shapes described by ``data_batch`` are compared with
    the shapes the current inner module (``self._curr_module``) is bound to.
    On any difference, a new ``Module`` is created for as many contexts as the
    batch has device slices. It is bound with ``shared_module`` set to the
    current module and replaces ``self._curr_module``. The forward pass is then
    delegated to the (possibly new) inner module.

    Args:
        data_batch: Batch exposing ``provide_data`` (one list of
            ``(name, shape)`` pairs per device). When ``is_train`` is truthy
            and the inner module has label shapes, it must also expose
            ``provide_label`` in the same per-device layout.
        is_train: When truthy, label shapes take part in the comparison.
            ``None`` is treated like ``False`` for the comparison and is
            passed through unchanged to the inner module's ``forward``.

    Raises:
        AssertionError: If the module is not bound or its parameters are not
            initialized.
        KeyError: If, while comparing, the batch lacks a name the inner
            module is bound with.
    """
    assert self.binded and self.params_initialized

    curr_module = self._curr_module
    # Compare like with like: include labels only when this pass supplies them
    # AND the bound module has them. Otherwise a labelled module compared
    # against a data-only (inference) batch would look up label names the
    # batch never provides.
    compare_labels = bool(is_train) and curr_module.label_shapes is not None

    # Iterate over the device slices that actually exist on each side instead
    # of len(self._context). After a rebind for a batch with fewer slices
    # (see the context list below), the inner module covers fewer devices than
    # self._context.
    if compare_labels:
        current_shapes = [dict(data + label) for data, label
                          in zip(curr_module.data_shapes, curr_module.label_shapes)]
        input_shapes = [dict(data + label) for data, label
                        in zip(data_batch.provide_data, data_batch.provide_label)]
    else:
        current_shapes = [dict(data) for data in curr_module.data_shapes]
        input_shapes = [dict(data) for data in data_batch.provide_data]

    # Only names the inner module is bound with are checked. Extra names in the
    # batch do not by themselves trigger a rebind.
    shape_changed = len(current_shapes) != len(input_shapes) or any(
        shape != cur[name]
        for pre, cur in zip(current_shapes, input_shapes)
        for name, shape in pre.items())

    if shape_changed:
        # Build and bind a fresh Module for the new shapes instead of reshaping
        # in place. Passing the current module as ``shared_module`` lets the new
        # one share with it (parameters, per MXNet Module.bind semantics;
        # confirm for this fork's Module).
        num_devices = len(data_batch.provide_data)
        module = Module(self._symbol, self._data_names, self._label_names,
                        logger=self.logger,
                        context=[self._context[i] for i in range(num_devices)],
                        work_load_list=self._work_load_list,
                        fixed_param_names=self._fixed_param_names)
        module.bind(data_batch.provide_data, data_batch.provide_label,
                    curr_module.for_training, curr_module.inputs_need_grad,
                    force_rebind=False, shared_module=curr_module)
        self._curr_module = module

    self._curr_module.forward(data_batch, is_train=is_train)
1058
1059 def backward(self, out_grads=None):
1060 assert self.binded and self.params_initialized

Callers

nothing calls this directly

Calls 3

bind (method) · 0.95
Module (class) · 0.70
forward (method) · 0.45

Tested by

no test coverage detected