Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
(model)
| 76 | |
| 77 | |
def unwrap_model(model):
    """
    Unwrap a model from a DataParallel or DistributedDataParallel wrapper.

    If ``model`` is an instance of ``nn.DataParallel`` or
    ``nn.parallel.DistributedDataParallel``, the wrapped module stored in its
    ``.module`` attribute is returned. Any other object (including a plain
    ``nn.Module``) is returned unchanged.

    Only a single level of wrapping is removed.
    """
    parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, parallel_wrappers) else model
| 86 | |
| 87 | |
| 88 | def get_predicted_classnames(logprobs, k, class_id_to_name): |
no outgoing calls
no test coverage detected