MCPcopy
hub / github.com/OpenPPL/ppq / get_input_shape

Function get_input_shape

ppq/parser/caffe/caffe_import_utils.py:85–112  ·  view source on GitHub ↗
(net_def: ppl_caffe_pb2.NetParameter)

Source from the content-addressed store, hash-verified

83 return graph
84
def get_input_shape(net_def: ppl_caffe_pb2.NetParameter) -> Dict[str, list]:
    """Collect the shape of every network input declared in a Caffe net definition.

    Only one input-shape format is consulted (mixed formats are not
    supported). The formats are tried in this order of precedence:

    1. ``net_def.input_shape``: one shape message per name in ``net_def.input``.
    2. ``net_def.input_dim``: a flat list holding 4 dims per input. Only 4-D
       inputs are supported through this format.
    3. Layers whose ``type`` is ``'Input'``: the first top of each such layer
       is mapped to ``layer.input_param.shape.dim``.

    Args:
        net_def: Parsed Caffe network definition.

    Returns:
        Mapping from input blob name to its shape as a list of dims.

    Raises:
        TypeError: if none of the formats above is present, or if a name
            listed in ``net_def.input`` ends up without a shape.
    """
    # Every declared input starts as "unknown" and must be filled in below.
    input_shape = {name: None for name in net_def.input}

    if len(net_def.input_shape) != 0:
        # Shapes are given with the input_shape field. zip() stops at the
        # shorter sequence, so an input with no matching shape entry stays
        # None and is reported below instead of raising IndexError.
        for name, shape in zip(net_def.input, net_def.input_shape):
            input_shape[name] = list(shape.dim)
    elif len(net_def.input_dim) != 0:
        # Shapes are given with input_dim.
        # TODO: only 4-D inputs are supported through input_dim.
        for i, name in enumerate(net_def.input):
            dims = list(net_def.input_dim[i * 4:(i + 1) * 4])
            # An empty slice means input_dim ran out of values for this input.
            # Leave it as None so it is reported as unspecified.
            if dims:
                input_shape[name] = dims
    else:
        # Shapes are given with Input layers. Only layers of type 'Input'
        # declare network inputs; other layers must not be registered here.
        input_layers = [layer for layer in net_def.layer if layer.type == 'Input']
        if not input_layers:
            raise TypeError('Unsupported network input format.')
        for layer in input_layers:
            # NOTE(review): assumes input_param.shape is a single shape message
            # in ppl_caffe.proto (upstream caffe.proto declares it repeated).
            # Confirm against the proto.
            input_shape[layer.top[0]] = list(layer.input_param.shape.dim)

    for name, shape in input_shape.items():
        if shape is None:
            raise TypeError("shape of input '%s' is not specified." % name)

    return input_shape
113
114def register_class(cls):
115 caffe_import_map[cls.__name__] = cls

Callers 1

buildMethod · 0.85

Calls 1

layer_existFunction · 0.85

Tested by

no test coverage detected