r"""Make sure data parameters are consistent during Data Parallel Mode. Args: model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked.
(model, parallel_mode)
| 12 | |
| 13 | |
def sync_model_param(model, parallel_mode):
    r"""Make sure model parameters are consistent within a parallel group.

    Each parameter of ``model`` is broadcast from the first rank listed by
    ``gpc.get_ranks_in_group(parallel_mode)`` over the process group of
    ``parallel_mode``. Nothing is done when that mode is not initialized or
    its world size is not greater than 1.

    Args:
        model (:class:`torch.nn.Module`): A PyTorch model whose parameters are synchronized.
        parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode whose
            group the parameters are synchronized across.
    """
    if not gpc.is_initialized(parallel_mode) or gpc.get_world_size(parallel_mode) <= 1:
        return

    # The source rank and the process group do not depend on the parameter,
    # so resolve them once instead of on every loop iteration.
    src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    group = gpc.get_group(parallel_mode)

    for param in model.parameters():
        dist.broadcast(param, src=src_rank, group=group)
| 25 | |
| 26 | |
| 27 | def sync_model_param_within_tp(model): |
no test coverage detected