MCPcopy
hub / github.com/microsoft/Cream / FlopsEst

Class FlopsEst

Cream/lib/utils/flops_table.py:11–83  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

9
10
11class FlopsEst(object):
def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'):
    """Profile every candidate op of ``model`` once and cache its cost.

    A random probe tensor of ``input_shape`` is pushed through the network.
    ``get_model_complexity_info`` is called on ``model.conv_stem``, on every
    choice of every module of every entry in ``model.blocks``, and finally on
    ``model.global_pool`` and ``model.conv_head``. Each sub-module is profiled
    at the activation shape it actually receives in that forward pass.

    Args:
        model: network exposing ``blocks`` (nested iterables indexed as
            ``blocks[block_id][module_id][choice_id]``), ``conv_stem``,
            ``global_pool`` and ``conv_head``.
        input_shape: shape of the probe tensor, e.g. the default
            ``(2, 3, 224, 224)``. The leading (batch) dimension is dropped
            when shapes are handed to the profiler.
        device: ``'cpu'`` keeps the model on the CPU. Any other value moves
            the model and the probe tensor to the current CUDA device. Note
            that ``model`` itself is moved.

    Sets:
        block_num: ``len(model.blocks)``.
        choice_num: ``len(model.blocks[0])``.
        flops_dict, params_dict: ``[block_id][module_id][choice_id]`` maps
            to that choice's cost as reported by the profiler, divided by
            1e6.
        flops_fixed, params_fixed: summed cost (divided by 1e6) of
            conv_stem, global_pool and conv_head.
    """
    self.block_num = len(model.blocks)
    # NOTE(review): this is the number of entries in the first block, not
    # the candidate count of a single module; confirm callers expect that.
    self.choice_num = len(model.blocks[0])
    self.flops_dict = {}
    self.params_dict = {}
    self.params_fixed = 0
    self.flops_fixed = 0

    # The probe must live on the same device as the model. A CPU tensor fed
    # to a CUDA model makes the very first forward pass fail.
    if device == 'cpu':
        model = model.cpu()
        run_device = torch.device('cpu')
    else:
        model = model.cuda()
        run_device = torch.device('cuda')

    def _measure(module, shape):
        # Profile ``module`` at per-sample ``shape`` (no batch dimension) and
        # return (flops, params), both divided by 1e6.
        flops, params = get_model_complexity_info(
            module, tuple(shape), as_strings=False,
            print_per_layer_stat=False)
        return flops / 1e6, params / 1e6

    # The probe only carries shapes forward, so autograd bookkeeping is
    # pure overhead here.
    with torch.no_grad():
        x = torch.randn(input_shape, device=run_device)

        # Stem: measured at the real probe resolution, so a non-default
        # ``input_shape`` is honoured.
        flops, params = _measure(model.conv_stem, x.shape[1:])
        self.flops_fixed += flops
        self.params_fixed += params
        x = model.conv_stem(x)

        for block_id, block in enumerate(model.blocks):
            self.flops_dict[block_id] = {}
            self.params_dict[block_id] = {}
            for module_id, module in enumerate(block):
                flops_row = {}
                params_row = {}
                last_choice = None
                for choice_id, choice in enumerate(module):
                    # Every candidate is profiled at the same input shape.
                    flops_row[choice_id], params_row[choice_id] = _measure(
                        choice, x.shape[1:])
                    last_choice = choice
                self.flops_dict[block_id][module_id] = flops_row
                self.params_dict[block_id][module_id] = params_row
                # Advance the probe through the last candidate to get the
                # next module's input. This presumably relies on all
                # candidates of a module sharing one output shape; verify.
                # An empty module leaves the probe unchanged.
                if last_choice is not None:
                    x = last_choice(x)

        # Global pooling.
        flops, params = _measure(model.global_pool, x.shape[1:])
        self.flops_fixed += flops
        self.params_fixed += params
        x = model.global_pool(x)

        # Head convolution, applied to the pooled features.
        flops, params = _measure(model.conv_head, x.shape[1:])
        self.flops_fixed += flops
        self.params_fixed += params
64
65 # return params (M)
66 def get_params(self, arch):
67 params = 0
68 for block_id, block in enumerate(arch):

Callers 1

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected