r"""Make sure data parameters are consistent during Data Parallel Mode. Args: model (:class:`torch.nn.Module`): A PyTorch model whose parameters are made consistent. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel mode to be checked. Note:
(model, parallel_mode)
| 39 | |
| 40 | |
def sync_model_param(model, parallel_mode):
    r"""Make data parameters consistent across ranks during Data Parallel Mode.

    Every parameter of ``model`` is broadcast in place from the first rank of the
    ``parallel_mode`` process group to all other ranks of that group. Nothing is
    done when ``parallel_mode`` is not initialized or its world size is 1.

    Args:
        model (:class:`torch.nn.Module`): A PyTorch model whose parameters are synchronized.
        parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel mode whose
            process group the parameters are synchronized over.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    # Short-circuits exactly like the original ``and``: the world size is only
    # queried once the mode is known to be initialized.
    if not gpc.is_initialized(parallel_mode) or gpc.get_world_size(parallel_mode) <= 1:
        return

    # The source rank and process group are the same for every parameter,
    # so resolve them once instead of once per parameter.
    src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    group = gpc.get_group(parallel_mode)
    for param in model.parameters():
        dist.broadcast(param, src=src_rank, group=group)
| 56 | |
| 57 | |
| 58 | def is_dp_rank_0(): |
no test coverage detected
searching dependent graphs…