(args)
| 281 | |
| 282 | |
def Optimize(args):
    """Apply the selected fusion passes to a Caffe2 net and save the result.

    The serialized init and predict ``NetDef`` protobufs are read from
    ``args.init_net`` and ``args.pred_net``. The passes enabled through
    ``args.fuse_mul_add``, ``args.fuse_bn`` and ``args.fuse_conv_relu`` are
    applied in that order, and the transformed nets are written to
    ``init_net.pb`` and ``predict_net.pb`` in the current working directory.

    If ``args.verify_input`` is set, it is parsed as JSON mapping each input
    blob name to a sequence whose last element is the blob shape. The predict
    net is run on random float32 inputs of those shapes before and after the
    transformation and every external output is compared.

    Args:
        args: parsed command-line namespace providing ``init_net``,
            ``pred_net``, ``verify_input``, ``fuse_mul_add``, ``fuse_bn``
            and ``fuse_conv_relu``.

    Raises:
        AssertionError: if verification is enabled and an external output of
            the transformed net is not close (atol=1e-3, rtol=1e-3) to the
            output of the original net.
    """
    init_net = caffe2_pb2.NetDef()
    predict_net = caffe2_pb2.NetDef()
    init_net.ParseFromString(args.init_net.read())
    predict_net.ParseFromString(args.pred_net.read())

    # Materialize the parameters by running the init net once, then snapshot
    # every blob in the workspace so the fusion passes can rewrite them.
    workspace.ResetWorkspace()
    workspace.RunNetOnce(init_net)
    param_dict = {p: workspace.FetchBlob(p) for p in workspace.Blobs()}

    # Reference run of the untouched net; only done when verification is on.
    external_inputs = {}
    external_outputs = {}
    if args.verify_input:
        value_info = json.load(args.verify_input)
        input_shapes = {k: v[-1] for (k, v) in value_info.items()}
        print("input info: {}".format(input_shapes))
        for k, v in input_shapes.items():
            external_inputs[k] = np.random.randn(*v).astype(np.float32)
            workspace.FeedBlob(k, external_inputs[k])
        workspace.RunNetOnce(predict_net)
        for o in predict_net.external_output:
            external_outputs[o] = workspace.FetchBlob(o)

    if args.fuse_mul_add:
        predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
    if args.fuse_bn:
        predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, False)
    if args.fuse_conv_relu:
        predict_net = fuse_conv_relu(predict_net)

    if args.verify_input:
        workspace.ResetWorkspace()
        # The conv+relu fusion is run on the IDEEP device; everything else
        # is verified on plain CPU.
        if args.fuse_conv_relu:
            device_option = core.DeviceOption(caffe2_pb2.IDEEP)
        else:
            device_option = core.DeviceOption(caffe2_pb2.CPU)
        with core.DeviceScope(device_option):
            for k, v in param_dict.items():
                workspace.FeedBlob(k, v, device_option)
            for k, v in external_inputs.items():
                workspace.FeedBlob(k, v, device_option)
            workspace.RunNetOnce(predict_net)
        for o in predict_net.external_output:
            optimized_output = workspace.FetchBlob(o)
            # Explicit raise instead of a bare ``assert`` so the check is
            # not stripped when the script runs under ``python -O``.
            if not np.allclose(external_outputs[o],
                               optimized_output,
                               atol=1e-3,
                               rtol=1e-3):
                raise AssertionError(
                    "output blob '{}' of the optimized net does not match "
                    "the original net (atol=1e-3, rtol=1e-3)".format(o))

    for i, o in enumerate(predict_net.op):
        print("op[{}]: {}".format(i, o.type))
    # Rebuild the init net from the (possibly rewritten) parameter blobs.
    init_net = gen_init_net_from_blobs(param_dict)
    with open('init_net.pb', 'wb') as f:
        f.write(init_net.SerializeToString())
    with open('predict_net.pb', 'wb') as f:
        f.write(predict_net.SerializeToString())
| 337 | |
| 338 | if __name__ == '__main__': |
| 339 | args = GetArgumentParser().parse_args() |
no test coverage detected
searching dependent graphs…