MCPcopy
hub / github.com/alibaba/EasyCV / count_flop

Function count_flop

tools/analyze_tools/count_flops.py:49–109  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

47
48
def count_flop():
    """Count and print the FLOPs of the model described by a config file.

    Reads command-line arguments via ``parse_args()``. Builds the model from
    ``args.config``, moves it to CUDA and puts it in eval mode. Then runs
    ``FlopCountAnalysis`` (which relies on JIT tracing) on up to
    ``args.repeat_num`` samples. Samples come from the ``val`` split, or from
    ``train`` when no ``val`` split is configured.

    Prints:
        * the per-module FLOP table of the last traced sample,
        * the average total GFLOPs per traced sample,
        * the average GFLOPs of each operator type.

    Raises:
        RuntimeError: if no sample was traced (``repeat_num`` <= 0 or the
            data loader yields nothing). Without this guard the report code
            would fail with an unrelated ``NameError``.
    """
    args = parse_args()

    device = 'cuda'
    cfg = mmcv_config_fromfile(args.config)

    # dynamic adapt mmdet models
    dynamic_adapt_for_mmlab(cfg)

    model = build_model(cfg.model)
    model.to(device)
    model.eval()

    if cfg.data.get('val', None) is not None:
        cfg.data.val.pop('imgs_per_gpu', None)  # pop useless params
        data_cfg = cfg.data.val
    else:
        data_cfg = cfg.data.train

    # Both branches load one sample per step.
    if is_dali_dataset_type(data_cfg['type']):
        data_cfg.distributed = False
        data_cfg.batch_size = 1
        data_cfg.workers_per_gpu = 1
        dataset = build_dataset(data_cfg)
        data_loader = dataset.get_dataloader()
    else:
        dataset = build_dataset(data_cfg)
        data_loader = build_dataloader(
            dataset, imgs_per_gpu=1, workers_per_gpu=0)

    handlers = {}  # mapping from operator names to handles.
    counts = Counter()
    gflop_unit = 1e9
    total_flops = []  # total FLOPs of each traced sample
    flops = None
    # zip stops at whichever is shorter: repeat_num or the data loader.
    for idx, data in zip(tqdm.trange(args.repeat_num), data_loader):
        # use scatter_kwargs to unpack DataContainer data for raw torch.nn.module
        _, kwargs = scatter_kwargs(None, data, [0])
        kwargs[0].update({'mode': 'test'})
        inputs = flatten_inputs(model, kwargs[0])

        # Provides access to per-submodule model flop count obtained by
        # tracing a model with pytorch's jit tracing functionality.
        # So models that do not support jit tracing may fail.
        flops = FlopCountAnalysis(model, inputs)
        flops.set_op_handle(**handlers)
        # Only report unsupported-op / uncalled-module warnings once.
        if idx > 0:
            flops.unsupported_ops_warnings(False).uncalled_modules_warnings(
                False)
        counts += flops.by_operator()
        total_flops.append(flops.total())

    num_samples = len(total_flops)
    if num_samples == 0:
        raise RuntimeError(
            'No sample was traced: check that repeat_num > 0 and that the '
            'dataset is not empty.')

    # `flops` holds the analysis of the last traced sample.
    print('Flops from only one sample is:\n' + flop_count_table(flops))
    print('Average total Gflops is: {}'.format(
        round(sum(total_flops) / num_samples / gflop_unit, 3)))
    ops_show = PrettyTable()
    ops_show.field_names = ['operator type', 'Gflops']
    for k, v in counts.items():
        ops_show.add_row([k, round(v / num_samples / gflop_unit, 3)])
    print('Average Gflops of each operator type is:')
    print(ops_show)

Callers 1

count_flops.py — File · 0.85

Calls 12

mmcv_config_fromfile — Function · 0.90
dynamic_adapt_for_mmlab — Function · 0.90
build_model — Function · 0.90
is_dali_dataset_type — Function · 0.90
build_dataset — Function · 0.90
build_dataloader — Function · 0.90
flatten_inputs — Function · 0.85
parse_args — Function · 0.70
to — Method · 0.45
get — Method · 0.45
get_dataloader — Method · 0.45
update — Method · 0.45

Tested by

no test coverage detected