MCPcopy
hub / github.com/PaddlePaddle/PaddleRec / get_strategy

Function get_strategy

tools/utils/static_ps/program_helper.py:30–87  ·  view source on GitHub ↗
(config)

Source from the content-addressed store, hash-verified

28
29
def get_strategy(config):
    """Build the fleet ``DistributedStrategy`` described by ``config``.

    Args:
        config: Mapping of flattened, dot-separated option names to values.
            It is read through ``config.get(key[, default])``, iterated for
            its keys, and indexed with ``config[key]``.

    Returns:
        ``None`` when ``common.is_distributed_env()`` is false (local
        training); otherwise a configured
        ``paddle.distributed.fleet.DistributedStrategy``.

    Raises:
        ValueError: If ``runner.sync_mode`` is not one of ``"async"``,
            ``"sync"``, ``"geo"``, ``"heter"`` or ``"gpubox"``.
    """
    if not common.is_distributed_env():
        logger.warning(
            "Not Find Distributed env, Change To local train mode. If you want train with fleet, please use [fleetrun] command."
        )
        return None

    sync_mode = config.get("runner.sync_mode")
    valid_modes = ("async", "sync", "geo", "heter", "gpubox")
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let an invalid mode fall through and fail
    # later with an unrelated UnboundLocalError on `strategy`.
    if sync_mode not in valid_modes:
        raise ValueError(
            "runner.sync_mode must be one of {}, got {!r}".format(
                list(valid_modes), sync_mode))

    if sync_mode == "gpubox":
        print("sync_mode = {}".format(sync_mode))

    strategy = paddle.distributed.fleet.DistributedStrategy()
    # Only "sync" disables a_sync; every other mode enables it.
    strategy.a_sync = sync_mode != "sync"
    if sync_mode == "geo":
        strategy.a_sync_configs = {"k_steps": config.get("runner.geo_step")}
    elif sync_mode == "heter":
        strategy.a_sync_configs = {"heter_worker_device_guard": "gpu"}
    elif sync_mode == "gpubox":
        strategy.a_sync_configs = {"use_ps_gpu": 1}

    strategy.trainer_desc_configs = {
        "dump_fields_path": config.get("runner.dump_fields_path", ""),
        "dump_fields": config.get("runner.dump_fields", []),
        "dump_param": config.get("runner.dump_param", []),
        # NOTE(review): unlike its siblings this key has no "runner."
        # prefix -- confirm that is intentional.
        "stat_var_names": config.get("stat_var_names", []),
    }
    print("strategy:", strategy.trainer_desc_configs)

    if config.get("runner.fs_client.uri") is not None:
        fs_client_param = {
            "uri": config.get("runner.fs_client.uri", ""),
            "user": config.get("runner.fs_client.user", ""),
            "passwd": config.get("runner.fs_client.passwd", ""),
            "hadoop_bin": config.get("runner.fs_client.hadoop_bin", "hadoop"),
        }
        strategy.fs_client_param = fs_client_param
        # Mask the password so the credential is not written to stdout/logs.
        print("strategy:", dict(fs_client_param, passwd="******"))

    strategy.adam_d2sum = config.get("hyper_parameters.adam_d2sum", True)

    # Group every "table_parameters.<table>..." option under its <table>
    # name (the second dot-separated component), keeping the full key.
    table_config = {}
    for key in config:
        if key.startswith("table_parameters"):
            table_name = key.split('.')[1]
            table_config.setdefault(table_name, {})[key] = config[key]
    print("table_config:", table_config)
    strategy.sparse_table_configs = table_config
    print("strategy table config:", strategy.sparse_table_configs)

    return strategy

Callers 5

init_networkMethod · 0.90
networkMethod · 0.90
networkMethod · 0.90
networkMethod · 0.90
networkMethod · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected