(config: DictConfig)
| 199 | |
| 200 | |
def init_model_parallel_config(config: DictConfig) -> ModelParallelConfig:
    """Build a Megatron ``ModelParallelConfig`` from ``config``.

    The following keys are read with ``config.get``, so a missing key is
    forwarded as ``None``:
    ``tensor_model_parallel_size``, ``pipeline_model_parallel_size``,
    ``virtual_pipeline_model_parallel_size``, ``sequence_parallel`` and
    ``param_dtype``.

    ``param_dtype`` is converted with ``PrecisionType.to_dtype`` and the
    result is passed as both ``params_dtype`` and ``pipeline_dtype``.
    Precision flags are fixed to ``bf16=True`` / ``fp16=False``, and a
    ``FakeTimers`` instance is supplied as ``timers``.

    Args:
        config: Configuration node holding the parallelism settings.

    Returns:
        The constructed ``ModelParallelConfig``.
    """
    # TODO(sgm): check how to disable megatron timers
    placeholder_timers = FakeTimers()

    parallel_layout = {
        'tensor_model_parallel_size': config.get('tensor_model_parallel_size'),
        'pipeline_model_parallel_size': config.get('pipeline_model_parallel_size'),
        'virtual_pipeline_model_parallel_size': config.get('virtual_pipeline_model_parallel_size'),
        'sequence_parallel': config.get('sequence_parallel'),
    }

    # One conversion of 'param_dtype', reused for both dtype fields below.
    weight_dtype = PrecisionType.to_dtype(config.get('param_dtype'))

    # NOTE(review): bf16/fp16 are fixed here regardless of 'param_dtype';
    # confirm this is intended when param_dtype is not bf16.
    return ModelParallelConfig(
        **parallel_layout,
        params_dtype=weight_dtype,
        pipeline_dtype=weight_dtype,
        bf16=True,
        fp16=False,
        timers=placeholder_timers,
    )
| 213 | |
| 214 | |
| 215 | class FakeTimers: |
no test coverage detected