MCPcopy Index your code
hub / github.com/pytorch/pytorch / Optimize

Function Optimize

caffe2/python/ideep/transform_ideep_net.py:283–336  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

281
282
def Optimize(args):
    """Apply operator fusions to a Caffe2 predict net and save the result.

    Parses serialized ``init_net`` / ``pred_net`` NetDefs from ``args``.
    Optionally applies the mul+add, batch-norm and conv+relu fusion passes.
    Prints every operator of the resulting predict net with its index.
    Writes the nets to ``init_net.pb`` and ``predict_net.pb``, paths relative
    to the current working directory.

    When ``args.verify_input`` is set, it is loaded with ``json.load`` as a
    mapping from external input name to a sequence whose last element is used
    as that input's shape. Random float32 tensors of those shapes are fed to
    both the original and the optimized net. Every external output of the
    optimized net must then match the reference within ``atol=rtol=1e-3``.

    Args:
        args: parsed command-line namespace. Attributes read:
            ``init_net`` and ``pred_net``: objects whose ``.read()`` returns a
                serialized NetDef.
            ``verify_input``: falsy, or something ``json.load`` accepts.
            ``fuse_mul_add``, ``fuse_bn``, ``fuse_conv_relu``: flags tested
                for truthiness.

    Raises:
        AssertionError: if verification is enabled and an optimized output
            does not match the corresponding reference output.
    """
    init_net = caffe2_pb2.NetDef()
    predict_net = caffe2_pb2.NetDef()
    init_net.ParseFromString(args.init_net.read())
    predict_net.ParseFromString(args.pred_net.read())

    # Run the init net once to materialize every parameter blob. The fusion
    # passes below take this name -> array mapping and return an updated one.
    workspace.ResetWorkspace()
    workspace.RunNetOnce(init_net)
    param_dict = {p: workspace.FetchBlob(p) for p in workspace.Blobs()}

    external_inputs = {}
    external_outputs = {}
    if args.verify_input:
        value_info = json.load(args.verify_input)
        # Only the last element of each entry is used, as the input's shape.
        input_shapes = {k: v[-1] for (k, v) in value_info.items()}
        print("input info: {}".format(input_shapes))
        for k, v in input_shapes.items():
            external_inputs[k] = np.random.randn(*v).astype(np.float32)
            workspace.FeedBlob(k, external_inputs[k])
        # Reference outputs from the unfused net, captured before any pass
        # rewrites predict_net.
        workspace.RunNetOnce(predict_net)
        for o in predict_net.external_output:
            external_outputs[o] = workspace.FetchBlob(o)

    if args.fuse_mul_add:
        predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
    if args.fuse_bn:
        predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, False)
    if args.fuse_conv_relu:
        predict_net = fuse_conv_relu(predict_net)

    external_outputs_opt = {}
    if args.verify_input:
        # Start from an empty workspace so the optimized net sees only the
        # (possibly rewritten) params and the same random inputs.
        workspace.ResetWorkspace()
        # With conv+relu fusion enabled the optimized net is run on the IDEEP
        # device; otherwise it is run on CPU.
        device_option = (core.DeviceOption(caffe2_pb2.IDEEP)
                         if args.fuse_conv_relu
                         else core.DeviceOption(caffe2_pb2.CPU))
        with core.DeviceScope(device_option):
            for k, v in param_dict.items():
                workspace.FeedBlob(k, v, device_option)
            for k, v in external_inputs.items():
                workspace.FeedBlob(k, v, device_option)
            workspace.RunNetOnce(predict_net)
            for o in predict_net.external_output:
                external_outputs_opt[o] = workspace.FetchBlob(o)
                # Explicit check rather than `assert`. An assert statement is
                # removed under `python -O`, which would silently skip
                # verification. This check also names the mismatching output.
                if not np.allclose(external_outputs[o],
                                   external_outputs_opt[o],
                                   atol=1e-3,
                                   rtol=1e-3):
                    raise AssertionError(
                        "optimized output '{}' does not match the reference "
                        "(shapes {} vs {})".format(
                            o,
                            np.shape(external_outputs[o]),
                            np.shape(external_outputs_opt[o])))

    for i, o in enumerate(predict_net.op):
        print("op[{}]: {}".format(i, o.type))
    init_net = gen_init_net_from_blobs(param_dict)
    with open('init_net.pb', 'wb') as f:
        f.write(init_net.SerializeToString())
    with open('predict_net.pb', 'wb') as f:
        f.write(predict_net.SerializeToString())
337
338if __name__ == '__main__':
339 args = GetArgumentParser().parse_args()

Callers 1

Calls 11

fuse_mul_add · Function · 0.85
fuse_conv_relu · Function · 0.85
gen_init_net_from_blobs · Function · 0.85
astype · Method · 0.80
fuse_bn · Function · 0.70
read · Method · 0.45
load · Method · 0.45
items · Method · 0.45
format · Method · 0.45
randn · Method · 0.45
write · Method · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…