MCPcopy
hub / github.com/whai362/PVT / main

Function main

detection/benchmark.py:42–111  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

40
41
42def main():
43 args = parse_args()
44
45 cfg = Config.fromfile(args.config)
46 if args.cfg_options is not None:
47 cfg.merge_from_dict(args.cfg_options)
48 # import modules from string list.
49 if cfg.get('custom_imports', None):
50 from mmcv.utils import import_modules_from_strings
51 import_modules_from_strings(**cfg['custom_imports'])
52 # set cudnn_benchmark
53 if cfg.get('cudnn_benchmark', False):
54 torch.backends.cudnn.benchmark = True
55 cfg.model.pretrained = None
56 cfg.data.test.test_mode = True
57
58 # build the dataloader
59 samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
60 if samples_per_gpu > 1:
61 # Replace 'ImageToTensor' to 'DefaultFormatBundle'
62 cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
63 dataset = build_dataset(cfg.data.test)
64 data_loader = build_dataloader(
65 dataset,
66 samples_per_gpu=1,
67 workers_per_gpu=cfg.data.workers_per_gpu,
68 dist=False,
69 shuffle=False)
70
71 # build the model and load checkpoint
72 cfg.model.train_cfg = None
73 model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
74 fp16_cfg = cfg.get('fp16', None)
75 if fp16_cfg is not None:
76 wrap_fp16_model(model)
77 load_checkpoint(model, args.checkpoint, map_location='cpu')
78 if args.fuse_conv_bn:
79 model = fuse_conv_bn(model)
80
81 model = MMDataParallel(model, device_ids=[0])
82
83 model.eval()
84
85 # the first several iterations may be very slow so skip them
86 num_warmup = 5
87 pure_inf_time = 0
88
89 # benchmark with 2000 image and take the average
90 for i, data in enumerate(data_loader):
91
92 torch.cuda.synchronize()
93 start_time = time.perf_counter()
94
95 with torch.no_grad():
96 model(return_loss=False, rescale=True, **data)
97
98 torch.cuda.synchronize()
99 elapsed = time.perf_counter() - start_time

Callers 1

benchmark.py · File · 0.70

Calls 3

build_dataset · Function · 0.90
print · Function · 0.85
parse_args · Function · 0.70

Tested by

no test coverage detected