Get the model parallel group the caller rank belongs to.
()
| 86 | |
| 87 | |
def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to.

    Returns:
        The value stored in the module-level ``_MODEL_PARALLEL_GROUP``.
        This is presumably a process-group handle assigned by an
        initialization routine elsewhere in this module (not visible here;
        confirm).

    Raises:
        AssertionError: if ``_MODEL_PARALLEL_GROUP`` is still ``None``,
            i.e. the group has not been initialized yet.
    """
    # Explicit raise instead of a bare ``assert``: ``assert`` statements are
    # stripped under ``python -O``, which would make this silently return
    # ``None``. AssertionError is kept so existing callers see the same type.
    if _MODEL_PARALLEL_GROUP is None:
        raise AssertionError('model parallel group is not initialized')
    return _MODEL_PARALLEL_GROUP
| 93 | |
| 94 | |
| 95 | def get_data_parallel_group(): |
no outgoing calls
no test coverage detected