Model summary. Args: model: Model instance.
(model: torch.nn.Module)
| 50 | |
| 51 | |
def model_summary(model: torch.nn.Module) -> str:
    """Build a human-readable summary of a PyTorch model.

    For every named parameter, one line with its name, dtype, device,
    trainability, shape and element count is printed to stdout. The returned
    string contains ``str(model)`` followed by the model's class name, the
    total and trainable parameter counts (formatted with
    ``get_human_readable_count``), the trainable percentage, and the dtype of
    the first parameter.

    Args:
        model: Model instance to summarize.

    Returns:
        The multi-line summary string. For a model without any parameters the
        trainable percentage is reported as ``0.0`` and the type as ``None``.
    """
    message = "Model structure:\n"
    message += str(model)

    total_count, trainable_count = 0, 0
    for name, param in model.named_parameters():
        numel = param.numel()
        print(
            "name: {}, dtype: {}, device: {}, trainable: {}, shape: {}, numel: {}".format(
                name, param.dtype, param.device, param.requires_grad, param.shape, numel
            )
        )
        total_count += numel
        if param.requires_grad:
            trainable_count += numel

    # Guard against parameter-free models (e.g. nn.Identity), which would
    # otherwise raise ZeroDivisionError here.
    if total_count > 0:
        percent_trainable = f"{trainable_count * 100.0 / total_count:.1f}"
    else:
        percent_trainable = "0.0"

    total_readable = get_human_readable_count(total_count)
    trainable_readable = get_human_readable_count(trainable_count)
    message += "\n\nModel summary:\n"
    message += f" Class Name: {model.__class__.__name__}\n"
    message += f" Total Number of model parameters: {total_readable}\n"
    message += f" Number of trainable parameters: {trainable_readable} ({percent_trainable}%)\n"

    # Report the dtype of the first parameter; a bare next() would raise
    # StopIteration for a model with no parameters.
    first_param = next(iter(model.parameters()), None)
    dtype = None if first_param is None else first_param.dtype
    message += f" Type: {dtype}"
    return message
no test coverage detected
searching dependent graphs…