(self, prototxt_path: str, caffemodel_path: str)
| 11 | |
| 12 | class CaffeParser(GraphBuilder): |
def load_graph_and_format(self, prototxt_path: str, caffemodel_path: str) -> ppl_caffe_pb2.NetParameter:
    """Load a Caffe network definition and its trained weights into one NetParameter.

    The text-format prototxt supplies the layer structure and the
    binary-serialized caffemodel supplies the blobs. The structure is passed
    through ``de_inplace`` before weights are attached, and
    ``merge_batchnorm_scale`` is applied to the merged result before it is
    returned.

    Args:
        prototxt_path: Path to the text-format network definition.
        caffemodel_path: Path to the binary-serialized weights file.

    Returns:
        The network definition. Each layer whose name matches a layer in the
        weights file has its ``blobs`` replaced by that layer's blobs. Layers
        with no matching name are left untouched.

    Raises:
        FileNotFoundError: If either path does not point to an existing file.
    """
    if not is_file_exist(prototxt_path):
        raise FileNotFoundError(f'file {prototxt_path} not exist, please check your file path')
    if not is_file_exist(caffemodel_path):
        raise FileNotFoundError(f'file {caffemodel_path} not exist, please check your file path')

    network = ppl_caffe_pb2.NetParameter()
    with open(prototxt_path) as f:
        text_format.Merge(f.read(), network)
    weight = ppl_caffe_pb2.NetParameter()
    with open(caffemodel_path, 'rb') as f:
        weight.ParseFromString(f.read())

    network = de_inplace(network)

    # Index the weight layers by name once. The previous version re-scanned
    # the whole weight list for every network layer, which is O(n * m).
    # setdefault keeps the FIRST layer seen for a duplicated name, which
    # matches the original scan-and-break behaviour.
    weight_layers = {}
    for weight_layer in weight.layer:
        weight_layers.setdefault(weight_layer.name, weight_layer)

    for layer in network.layer:
        source = weight_layers.get(layer.name)
        if source is None:
            continue
        # Replace (rather than append to) any blobs the prototxt carried.
        layer.ClearField('blobs')
        layer.blobs.MergeFrom(source.blobs)

    return merge_batchnorm_scale(network)
| 36 | |
| 37 | def build(self, prototxt_path: str, caffemodel_path: str) -> BaseGraph: |
| 38 | network = self.load_graph_and_format(prototxt_path, caffemodel_path) |
no test coverage detected