| 83 | return graph |
| 84 | |
def get_input_shape(net_def: ppl_caffe_pb2.NetParameter) -> Dict[str, list]:
    """Collect the declared shape of every input of a Caffe network.

    Only one declaration format is honoured per network (mixed formats are
    not supported). The formats are probed in this order:

    1. the ``input_shape`` field, one entry per name in ``net_def.input``;
    2. the flat ``input_dim`` field, consumed four values per input;
    3. layers whose ``type`` is ``'Input'``; the shape is keyed by the
       layer's first top.

    Args:
        net_def: Parsed network definition. The fields read are ``input``,
            ``input_shape``, ``input_dim`` and ``layer``.

    Returns:
        A dict mapping each input name to its shape as a list of dims.

    Raises:
        TypeError: If none of the three formats is present, or if a name
            listed in ``net_def.input`` ends up without a shape.
    """
    # Start with every declared input unresolved so that gaps are detected
    # by the final validation pass.
    input_shape = {name: None for name in net_def.input}

    if len(net_def.input_shape) != 0:
        # Given input shape use input_shape field
        for i, name in enumerate(net_def.input):
            input_shape[name] = list(net_def.input_shape[i].dim)
    elif len(net_def.input_dim) != 0:
        # Given input shape use input_dim
        # TODO: Here only support 4-D input
        for i, name in enumerate(net_def.input):
            input_shape[name] = list(net_def.input_dim[i * 4:(i + 1) * 4])
    else:
        # Given input shape use input layer.
        # Only layers of type 'Input' describe network inputs; every other
        # layer must be ignored, otherwise its top would be reported as an
        # input (or a layer without tops would raise IndexError).
        input_layers = [layer for layer in net_def.layer
                        if layer.type == 'Input']
        if not input_layers:
            raise TypeError('Unsupported network input format.')
        for layer in input_layers:
            # NOTE(review): upstream caffe.proto declares
            # InputParameter.shape as a repeated field; the attribute access
            # below is kept as it was - confirm against ppl_caffe_pb2.
            input_shape[layer.top[0]] = list(layer.input_param.shape.dim)

    for name, shape in input_shape.items():
        if shape is None:
            raise TypeError("shape of input '%s' is not specified." % name)

    return input_shape
| 113 | |
| 114 | def register_class(cls): |
| 115 | caffe_import_map[cls.__name__] = cls |