MCP · Copy · Index your code
hub / github.com/pytorch/pytorch / Benchmark

Function Benchmark

caffe2/python/convnet_benchmarks.py:577–640  ·  view source on GitHub ↗
(model_gen, arg)

Source retrieved from the content-addressed store (hash-verified)

575
576
577def Benchmark(model_gen, arg):
578 model, input_size = model_gen(arg.order, arg.cudnn_ws)
579 model.Proto().type = arg.net_type
580 model.Proto().num_workers = arg.num_workers
581
582 # In order to be able to run everything without feeding more stuff, let's
583 # add the data and label blobs to the parameter initialization net as well.
584 if arg.order == "NCHW":
585 input_shape = [arg.batch_size, 3, input_size, input_size]
586 else:
587 input_shape = [arg.batch_size, input_size, input_size, 3]
588 if arg.model == "MLP":
589 input_shape = [arg.batch_size, input_size]
590
591 model.param_init_net.GaussianFill(
592 [],
593 "data",
594 shape=input_shape,
595 mean=0.0,
596 std=1.0
597 )
598 model.param_init_net.UniformIntFill(
599 [],
600 "label",
601 shape=[arg.batch_size, ],
602 min=0,
603 max=999
604 )
605
606 if arg.forward_only:
607 print('{}: running forward only.'.format(arg.model))
608 else:
609 print('{}: running forward-backward.'.format(arg.model))
610 model.AddGradientOperators(["loss"])
611 AddParameterUpdate(model)
612 if arg.order == 'NHWC':
613 print(
614 '==WARNING==\n'
615 'NHWC order with CuDNN may not be supported yet, so I might\n'
616 'exit suddenly.'
617 )
618
619 if not arg.cpu:
620 model.param_init_net.RunAllOnGPU()
621 model.net.RunAllOnGPU()
622
623 if arg.engine:
624 for op in model.net.Proto().op:
625 op.engine = arg.engine
626
627 if arg.dump_model:
628 # Writes out the pbtxt for benchmarks on e.g. Android
629 with open(
630 "{0}_init_batch_{1}.pbtxt".format(arg.model, arg.batch_size), "w"
631 ) as fid:
632 fid.write(str(model.param_init_net.Proto()))
633 with open("{0}.pbtxt".format(arg.model, arg.batch_size), "w") as fid:
634 fid.write(str(model.net.Proto()))

Callers: 1

Calls: 6

AddParameterUpdate (function) · 0.70
Proto (method) · 0.45
format (method) · 0.45
AddGradientOperators (method) · 0.45
RunAllOnGPU (method) · 0.45
write (method) · 0.45

Tested by

No test coverage detected.

Used in the wild — real call sites across dependent graphs

searching dependent graphs…