| 41 | self.aux_shape_dict = dict(zip(self.sym.list_auxiliary_states(), aux_shape)) |
| 42 | |
def check_parameter_shapes(self, arg_params, aux_params, data_shape_dict, is_train=True):
    """Verify that every parameter of ``self.sym`` is supplied with the expected shape.

    Every argument name from ``self.sym.list_arguments()`` is checked against
    ``arg_params`` and ``self.arg_shape_dict``. Every auxiliary-state name from
    ``self.sym.list_auxiliary_states()`` is checked against ``aux_params`` and
    ``self.aux_shape_dict``. Both shape dicts are presumably filled by a prior
    shape-inference call on this object; confirm against the caller.

    :param arg_params: mapping from argument name to an object exposing ``.shape``.
    :param aux_params: mapping from auxiliary-state name to an object exposing ``.shape``.
    :param data_shape_dict: container of argument names to skip; only membership is tested.
    :param is_train: when False, arguments whose name contains ``'label'`` are
        also skipped.
    :raises AssertionError: if a required parameter is missing, or if its
        ``.shape`` differs from the inferred shape.
    """

    def _check(name, params, expected_shapes):
        # Validate one entry. The code raises explicitly instead of using
        # ``assert``, so the check still runs under ``python -O``.
        # AssertionError is kept so existing callers see the same exception type.
        if name not in params:
            raise AssertionError(name + ' not initialized')
        provided = params[name].shape
        inferred = expected_shapes[name]
        if provided == inferred:
            return
        raise AssertionError('shape inconsistent for ' + name + ' inferred ' + str(inferred)
                             + ' provided ' + str(provided))

    for k in self.sym.list_arguments():
        # Skip names listed in data_shape_dict and, outside training, any 'label' argument.
        if k in data_shape_dict or (not is_train and 'label' in k):
            continue
        _check(k, arg_params, self.arg_shape_dict)
    for k in self.sym.list_auxiliary_states():
        _check(k, aux_params, self.aux_shape_dict)