(model, config)
| 16 | from model.model_minimind import MiniMindForCausalLM |
| 17 | |
def get_model_params(model, config, *, log=None):
    """Log the model's total parameter count and, for MoE models, the active count.

    The active count is: all non-expert parameters, plus the shared experts,
    plus ``num_experts_per_tok`` routed experts. Expert sizes are measured from
    every parameter whose name contains ``'mlp.experts.0.'`` (resp.
    ``'mlp.shared_experts.0.'``), i.e. expert 0 summed over all matching
    layers. Every routed/shared expert is therefore assumed to be the same
    size as expert 0.

    Args:
        model: module exposing ``parameters()`` and ``named_parameters()``.
        config: model config. ``n_routed_experts`` (falling back to
            ``num_experts``), ``num_experts_per_tok`` and ``n_shared_experts``
            are read when present; missing or ``None`` values count as 0.
        log: callable that receives the summary string; defaults to ``Logger``.

    Returns:
        ``(total, active)`` as raw integer parameter counts.
    """
    if log is None:
        log = Logger

    def _cfg_count(*names):
        # The first attribute that is present AND not None wins. Some configs
        # may declare expert counts as None, which must not reach the
        # arithmetic below (and must not block the `num_experts` fallback).
        for attr in names:
            value = getattr(config, attr, None)
            if value is not None:
                return int(value)
        return 0

    def _params_matching(marker):
        return sum(p.numel() for n, p in model.named_parameters() if marker in n)

    n_routed = _cfg_count('n_routed_experts', 'num_experts')
    n_active = _cfg_count('num_experts_per_tok')
    n_shared = _cfg_count('n_shared_experts')

    # Work in exact integer counts and only scale to millions for display, so
    # the `active < total` test below is not perturbed by float rounding.
    total = sum(p.numel() for p in model.parameters())
    expert = _params_matching('mlp.experts.0.')
    shared_expert = _params_matching('mlp.shared_experts.0.')
    base = total - expert * n_routed - shared_expert * n_shared
    active = base + expert * n_active + shared_expert * n_shared

    total_m, active_m = total / 1e6, active / 1e6
    if active < total:
        log(f'Model Params: {total_m:.2f}M-A{active_m:.2f}M')
    else:
        log(f'Model Params: {total_m:.2f}M')
    return total, active
| 29 | |
| 30 | |
| 31 | def is_main_process(): |
no test coverage detected