(config)
| 28 | |
| 29 | |
def get_strategy(config):
    """Build a fleet ``DistributedStrategy`` from a flattened runner config.

    Args:
        config: Mapping keyed by dotted paths (e.g. ``"runner.sync_mode"``,
            ``"table_parameters.<table>.<field>"``). It is read with
            ``.get``, iterated for its keys, and indexed with ``[]``.

    Returns:
        A configured ``paddle.distributed.fleet.DistributedStrategy``, or
        ``None`` when ``common.is_distributed_env()`` reports no distributed
        environment (local training mode).

    Raises:
        ValueError: if ``runner.sync_mode`` is not one of
            ``"async"``, ``"sync"``, ``"geo"``, ``"heter"``, ``"gpubox"``.
    """
    if not common.is_distributed_env():
        # logger.warn is a deprecated alias of logger.warning.
        logger.warning(
            "Not Find Distributed env, Change To local train mode. If you want train with fleet, please use [fleetrun] command."
        )
        return None

    sync_mode = config.get("runner.sync_mode")
    # Explicit check instead of `assert`: under `python -O` an assert is
    # stripped and an unsupported mode would surface later as an
    # UnboundLocalError on `strategy`.
    supported_modes = ("async", "sync", "geo", "heter", "gpubox")
    if sync_mode not in supported_modes:
        raise ValueError(
            "runner.sync_mode must be one of {}, got {!r}".format(
                list(supported_modes), sync_mode))

    if sync_mode == "gpubox":
        print("sync_mode = {}".format(sync_mode))

    strategy = paddle.distributed.fleet.DistributedStrategy()
    # Only "sync" trains synchronously; every other supported mode is a_sync.
    strategy.a_sync = sync_mode != "sync"
    if sync_mode == "geo":
        strategy.a_sync_configs = {"k_steps": config.get("runner.geo_step")}
    elif sync_mode == "heter":
        strategy.a_sync_configs = {"heter_worker_device_guard": "gpu"}
    elif sync_mode == "gpubox":
        strategy.a_sync_configs = {"use_ps_gpu": 1}

    strategy.trainer_desc_configs = {
        "dump_fields_path": config.get("runner.dump_fields_path", ""),
        "dump_fields": config.get("runner.dump_fields", []),
        "dump_param": config.get("runner.dump_param", []),
        # NOTE(review): unlike the keys above, this is read from a top-level
        # "stat_var_names" key, not "runner.stat_var_names" — confirm intended.
        "stat_var_names": config.get("stat_var_names", [])
    }
    print("strategy:", strategy.trainer_desc_configs)

    # File-system client settings are only applied when a URI is configured.
    if config.get("runner.fs_client.uri") is not None:
        strategy.fs_client_param = {
            "uri": config.get("runner.fs_client.uri", ""),
            "user": config.get("runner.fs_client.user", ""),
            "passwd": config.get("runner.fs_client.passwd", ""),
            "hadoop_bin": config.get("runner.fs_client.hadoop_bin", "hadoop")
        }
        print("strategy:", strategy.fs_client_param)

    strategy.adam_d2sum = config.get("hyper_parameters.adam_d2sum", True)

    # Group "table_parameters.<table_name>...." entries by table name; each
    # group keeps the full dotted key -> value pairs. Requiring the trailing
    # dot skips keys that only share the prefix and avoids an IndexError on
    # a bare "table_parameters" key.
    table_config = {}
    for key in config:
        if key.startswith("table_parameters."):
            table_name = key.split(".")[1]
            table_config.setdefault(table_name, {})[key] = config[key]
    print("table_config:", table_config)
    strategy.sparse_table_configs = table_config
    print("strategy table config:", strategy.sparse_table_configs)

    return strategy
no outgoing calls
no test coverage detected