hub / github.com/Tele-AI/Telechat / get_train_ds_config

Function get_train_ds_config

deepspeed-telechat/utils/ds_utils.py:8–49
(offload, stage=2, precision="fp16")

Source from the content-addressed store, hash-verified

def get_train_ds_config(offload, stage=2, precision="fp16"):
    device = "cpu" if offload else "none"
    zero_opt_dict = {
        "stage": stage,
        "offload_param": {
            "device": device
        },
        "offload_optimizer": {
            "device": device
        },
        "stage3_param_persistence_threshold": 1e4,
        "stage3_max_live_parameters": 3e7,
        "stage3_prefetch_bucket_size": 3e7,
    }
    ds_config = {
        "train_batch_size": GLOBAL_BATCH_SIZE,
        "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
        "steps_per_print": 1,
        "zero_optimization": zero_opt_dict,
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "wall_clock_breakdown": False,
        "checkpoint": {
            "use_node_local_storage": True
        }
    }
    if precision == "fp16":
        ds_config["fp16"] = {
            "enabled": True,
            "loss_scale": 0,
            "loss_scale_window": 500,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 12
        }
    elif precision == "bf16":
        ds_config["bf16"] = {"enabled": True}
    else:
        raise ValueError("Mixed Precision type must be selected")
    return ds_config
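
The returned dictionary takes its batch sizes from the module-level constants GLOBAL_BATCH_SIZE and MICRO_BATCH_SIZE, so a caller will usually want to overwrite them before handing the config to DeepSpeed. This page does not show how main does that; the sketch below is one possible approach, not the repository's code. The helper name apply_batch_sizes and its parameters are hypothetical. It relies only on the DeepSpeed rule that train_batch_size must equal the per-GPU micro-batch size times the gradient accumulation steps times the number of GPUs.

```python
def apply_batch_sizes(ds_config, per_device_batch_size, world_size, grad_accum_steps=1):
    """Overwrite the batch-size fields of a DeepSpeed config dict in place.

    Hypothetical helper: the name and parameters are illustrative assumptions.
    """
    ds_config["train_micro_batch_size_per_gpu"] = per_device_batch_size
    # DeepSpeed expects: global batch = micro batch * accumulation steps * GPUs.
    ds_config["train_batch_size"] = (
        per_device_batch_size * world_size * grad_accum_steps
    )
    return ds_config


# A stand-in for the dict returned by get_train_ds_config, reduced to the
# two fields this helper touches; the starting values are placeholders.
config = {"train_batch_size": 32, "train_micro_batch_size_per_gpu": 4}
apply_batch_sizes(config, per_device_batch_size=2, world_size=8, grad_accum_steps=4)
print(config["train_batch_size"])  # 64
```

Keeping the override in the caller leaves get_train_ds_config free of any dependency on command-line arguments or the distributed world size.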

Callers 1

main · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected
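
Since no tests are detected, a small structural check could serve as a starting point. The following is a minimal sketch under stated assumptions: check_ds_config is a hypothetical helper, not part of the repository, and it inspects only the keys visible in the listing above (the ZeRO stage, the two offload devices, and the precision block).

```python
def check_ds_config(ds_config, stage, offload, precision):
    """Return a list of mismatches between a config dict and the requested options.

    Hypothetical helper for illustration; an empty list means no problem found.
    """
    problems = []
    zero = ds_config.get("zero_optimization", {})
    if zero.get("stage") != stage:
        problems.append("unexpected ZeRO stage")

    # The function maps offload=True to "cpu" and offload=False to "none".
    expected_device = "cpu" if offload else "none"
    for key in ("offload_param", "offload_optimizer"):
        if zero.get(key, {}).get("device") != expected_device:
            problems.append(f"{key} device is not {expected_device}")

    # Exactly one of the fp16 / bf16 blocks should be present and enabled.
    other = "bf16" if precision == "fp16" else "fp16"
    if not ds_config.get(precision, {}).get("enabled"):
        problems.append(f"{precision} block missing or disabled")
    if other in ds_config:
        problems.append(f"unexpected {other} block")
    return problems
```

With the real function importable, a test could assert that check_ds_config(get_train_ds_config(offload=True, stage=3, precision="bf16"), 3, True, "bf16") returns an empty list, and separately that an unsupported precision string raises ValueError.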