| 6 | |
| 7 | |
def get_train_ds_config(offload, stage=2, precision="fp16"):
    """Assemble and return the training configuration dictionary.

    The name and keys suggest a DeepSpeed config (NOTE(review): confirm
    against the caller that consumes the returned dict).

    Args:
        offload: Truthy selects ``"cpu"`` as the device for both the
            ``offload_param`` and ``offload_optimizer`` entries; falsy
            selects ``"none"``.
        stage: Value stored under ``zero_optimization["stage"]``.
        precision: ``"fp16"`` or ``"bf16"``; picks which mixed-precision
            section is added to the config.

    Returns:
        A new dict holding the batch-size, ZeRO, gradient, checkpoint and
        mixed-precision settings.

    Raises:
        ValueError: If ``precision`` is neither ``"fp16"`` nor ``"bf16"``.

    Batch sizes are read from the module-level ``GLOBAL_BATCH_SIZE`` and
    ``MICRO_BATCH_SIZE`` names.
    """
    if offload:
        offload_target = "cpu"
    else:
        offload_target = "none"

    # The same target is used for parameters and optimizer state; each
    # entry gets its own dict object so they can be mutated independently.
    zero_section = {
        "stage": stage,
        "offload_param": {"device": offload_target},
        "offload_optimizer": {"device": offload_target},
        "stage3_param_persistence_threshold": 1e4,
        "stage3_max_live_parameters": 3e7,
        "stage3_prefetch_bucket_size": 3e7,
    }

    config = {
        "train_batch_size": GLOBAL_BATCH_SIZE,
        "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
        "steps_per_print": 1,
        "zero_optimization": zero_section,
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
        "checkpoint": {"use_node_local_storage": True},
    }

    # Exactly one mixed-precision section is appended; anything other than
    # the two supported names is rejected.
    if precision == "fp16":
        config["fp16"] = {
            "enabled": True,
            # presumably 0 selects dynamic loss scaling — confirm in the
            # DeepSpeed config reference
            "loss_scale": 0,
            "loss_scale_window": 500,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 12,
        }
        return config

    if precision == "bf16":
        config["bf16"] = {"enabled": True}
        return config

    raise ValueError("Mixed Precision type must be selected")