()
| 47 | |
| 48 | |
def count_flop():
    """Estimate the FLOPs of the model described by ``args.config``.

    Builds the model from the config, moves it to ``'cuda'`` in eval mode and
    runs ``FlopCountAnalysis`` on up to ``args.repeat_num`` samples taken from
    the validation dataset (or the training dataset when the config has no
    ``val`` entry). Prints the per-module FLOP table of the last analysed
    sample, followed by a table of the average GFLOPs per operator type over
    all analysed samples.

    Raises:
        RuntimeError: if no sample was analysed, i.e. ``args.repeat_num`` is
            not positive or the data loader yields nothing.
    """
    args = parse_args()

    device = 'cuda'
    cfg = mmcv_config_fromfile(args.config)

    # dynamic adapt mmdet models
    dynamic_adapt_for_mmlab(cfg)

    model = build_model(cfg.model)
    model.to(device)
    model.eval()

    # Prefer the validation split; fall back to the training split.
    if cfg.data.get('val', None) is not None:
        cfg.data.val.pop('imgs_per_gpu', None)  # pop useless params
        data_cfg = cfg.data.val
    else:
        data_cfg = cfg.data.train

    if is_dali_dataset_type(data_cfg['type']):
        # DALI datasets create their own loader, so configure a single,
        # non-distributed sample per batch before building the dataset.
        data_cfg.distributed = False
        data_cfg.batch_size = 1
        data_cfg.workers_per_gpu = 1
        dataset = build_dataset(data_cfg)
        data_loader = dataset.get_dataloader()
    else:
        dataset = build_dataset(data_cfg)
        data_loader = build_dataloader(
            dataset, imgs_per_gpu=1, workers_per_gpu=0)

    handlers = {}  # mapping from operator names to custom flop handles
    counts = Counter()
    gflop_unit = 1e9
    total_flops = []
    idx = -1  # remains -1 when the loop below never executes
    # zip stops at whichever is shorter: repeat_num or the data loader.
    for idx, data in zip(tqdm.trange(args.repeat_num), data_loader):
        # use scatter_kwargs to unpack DataContainer data for raw torch.nn.module
        _, kwargs = scatter_kwargs(None, data, [0])
        kwargs[0].update({'mode': 'test'})
        inputs = flatten_inputs(model, kwargs[0])

        # Provides access to per-submodule model flop count obtained by
        # tracing a model with pytorch's jit tracing functionality.
        # So models that do not support jit tracing may fail.
        flops = FlopCountAnalysis(model, inputs)
        flops.set_op_handle(**handlers)
        if idx > 0:
            # Report unsupported-op / uncalled-module warnings only once,
            # for the first sample.
            flops.unsupported_ops_warnings(False).uncalled_modules_warnings(
                False)
        counts += flops.by_operator()
        total_flops.append(flops.total())

    num_samples = len(total_flops)
    if num_samples == 0:
        # Without this guard, `flops` below is unbound and the average
        # divides by zero.
        raise RuntimeError(
            'No sample was analysed: check that repeat_num is positive '
            'and that the dataset is not empty.')

    print('Flops from only one sample is:\n' + flop_count_table(flops))
    ops_show = PrettyTable()
    ops_show.field_names = ['operator type', 'Gflops']
    for k, v in counts.items():
        ops_show.add_row([k, round(v / num_samples / gflop_unit, 3)])
    print('Average Gflops of each operator type is:')
    print(ops_show)
no test coverage detected