(model_gen, arg)
| 575 | |
| 576 | |
def Benchmark(model_gen, arg):
    """Build a benchmark model and prepare its nets for execution.

    Args:
        model_gen: callable invoked as ``model_gen(arg.order, arg.cudnn_ws)``;
            it must return a ``(model, input_size)`` pair.
        arg: parsed options object. The attributes read here are ``order``,
            ``cudnn_ws``, ``net_type``, ``num_workers``, ``batch_size``,
            ``model``, ``forward_only``, ``cpu``, ``engine`` and
            ``dump_model``.

    Side effects:
        * Sets ``type`` and ``num_workers`` on the model's net proto.
        * Adds fill ops for the "data" and "label" blobs to
          ``model.param_init_net`` so no external feeding is needed.
        * Unless ``arg.forward_only``, adds gradient operators for "loss"
          and calls ``AddParameterUpdate(model)``.
        * Unless ``arg.cpu``, calls ``RunAllOnGPU()`` on both nets.
        * If ``arg.engine`` is set, overrides the engine of every op in
          ``model.net`` (``param_init_net`` ops are left unchanged).
        * If ``arg.dump_model``, writes ``<model>_init_batch_<N>.pbtxt`` and
          ``<model>.pbtxt`` (relative paths, i.e. the current directory).

    NOTE(review): within this view the function does not itself run or time
    the nets; confirm whether that happens in code not shown here.
    """
    model, input_size = model_gen(arg.order, arg.cudnn_ws)
    model.Proto().type = arg.net_type
    model.Proto().num_workers = arg.num_workers

    # In order to be able to run everything without feeding more stuff,
    # add the data and label blobs to the parameter initialization net.
    if arg.model == "MLP":
        # The MLP input is 2-D (batch, features), regardless of arg.order.
        input_shape = [arg.batch_size, input_size]
    elif arg.order == "NCHW":
        input_shape = [arg.batch_size, 3, input_size, input_size]
    else:
        input_shape = [arg.batch_size, input_size, input_size, 3]

    model.param_init_net.GaussianFill(
        [],
        "data",
        shape=input_shape,
        mean=0.0,
        std=1.0
    )
    # Labels are random ints with min=0 / max=999; presumably a 1000-class
    # (ImageNet-style) label space -- TODO confirm against the model defs.
    model.param_init_net.UniformIntFill(
        [],
        "label",
        shape=[arg.batch_size, ],
        min=0,
        max=999
    )

    if arg.forward_only:
        print('{}: running forward only.'.format(arg.model))
    else:
        print('{}: running forward-backward.'.format(arg.model))
        model.AddGradientOperators(["loss"])
        AddParameterUpdate(model)
        if arg.order == 'NHWC':
            print(
                '==WARNING==\n'
                'NHWC order with CuDNN may not be supported yet, so I might\n'
                'exit suddenly.'
            )

    if not arg.cpu:
        model.param_init_net.RunAllOnGPU()
        model.net.RunAllOnGPU()

    if arg.engine:
        # Only the main net's ops get the engine override.
        for op in model.net.Proto().op:
            op.engine = arg.engine

    if arg.dump_model:
        # Writes out the pbtxt for benchmarks on e.g. Android.
        # Only the init net file is tagged with the batch size, since that
        # is where the batch-sized "data"/"label" fills live.
        with open(
            "{0}_init_batch_{1}.pbtxt".format(arg.model, arg.batch_size), "w"
        ) as fid:
            fid.write(str(model.param_init_net.Proto()))
        # Previously this also passed arg.batch_size, which the template
        # never used (str.format silently ignores extra arguments).
        with open("{0}.pbtxt".format(arg.model), "w") as fid:
            fid.write(str(model.net.Proto()))
no test coverage detected
searching dependent graphs…