(model, tsize)
| 17 | |
| 18 | |
def get_model_info(model, tsize):
    """Return a one-line summary of a model's parameter count and estimated GFLOPs.

    The model is profiled once on a small all-zero ``1 x 3 x 64 x 64`` input.
    The measured operation count is then scaled by pixel ratio up to the
    target input size ``tsize``.

    Args:
        model: the ``torch.nn.Module`` to profile. A deep copy is profiled,
            so the caller's model is never passed to ``profile`` directly.
        tsize: target input size. Only ``tsize[0]`` and ``tsize[1]`` are
            read (presumably ``(height, width)`` -- TODO confirm with callers).

    Returns:
        A string of the form ``"Params: X.XXM, Gflops: Y.YY"``.
    """
    stride = 64
    # Put the dummy input on the same device as the model's weights. For a
    # model with no parameters, fall back to torch's default device instead
    # of letting ``next`` raise a bare, confusing StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    img = torch.zeros((1, 3, stride, stride), device=device)

    # ``profile`` is presumably ``thop.profile``. It returns
    # (operation count, parameter count) for a single forward pass on ``img``.
    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
    params /= 1e6  # -> millions of parameters
    flops /= 1e9  # -> billions of operations at the 64x64 probe size
    # Extrapolate from the 64x64 probe to tsize by pixel-count ratio. This
    # assumes the cost grows linearly with input area (holds for fully
    # convolutional nets; otherwise an approximation). The trailing factor 2
    # presumably converts multiply-accumulates (MACs) to FLOPs.
    flops *= tsize[0] * tsize[1] / stride / stride * 2
    return "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
| 29 | |
| 30 | |
| 31 | def fuse_conv_and_bn(conv, bn): |