(model)
| 4 | |
| 5 | |
def count_model_parameters(model):
    """Count the trainable parameters of a torch module and log the total.

    Only parameters with ``requires_grad`` set are counted. The reported
    size assumes 4 bytes per parameter (float32) and is expressed in MB.

    Args:
        model: Object to inspect; expected to be a ``torch.nn.Module``.

    Returns:
        A ``(count, size_mb)`` tuple. If ``model`` is not a
        ``torch.nn.Module`` the result is ``(0, 0)`` and nothing is logged.
    """
    if isinstance(model, torch.nn.Module):
        # Accumulate element counts of trainable parameters only.
        trainable = 0
        for param in model.parameters():
            if param.requires_grad:
                trainable += param.numel()

        # 4 bytes per float32 element, then bytes -> KB -> MB.
        total_bytes = trainable * 4.0
        megabytes = total_bytes / 1024.0 / 1024.0

        klass = model.__class__
        label = f"{klass.__name__} {klass}"
        logging.info(
            f"#param of {label} is {trainable} = {megabytes:.1f} MB (float32)"
        )
        return trainable, megabytes

    # Anything that is not a torch module has no parameters to report.
    return 0, 0