| 67 | |
| 68 | |
| 69 | class Main(object): |
def __init__(self, config):
    """Store the run configuration and reset all per-run state.

    Args:
        config: dict-like run configuration. It must contain the key
            ``'pure_bf16'``, which is read with ``[]``, so a missing key
            raises ``KeyError``.
    """
    self.config = config
    self.metrics = {}
    # Collected results; the "speed" list is pre-created empty.
    self.train_result_dict = {"speed": []}
    # Placeholders filled in later: model, input_data and metrics by
    # network(), exe by run_worker(), reader presumably by init_reader().
    self.model = None
    self.input_data = None
    self.reader = None
    self.exe = None
    # Read last, after all other attributes are set.
    self.pure_bf16 = self.config['pure_bf16']
| 80 | |
def run(self):
    """Run this process end to end according to its fleet role.

    The steps are:
      1. Initialize fleet via ``init_fleet_with_gloo()``.
      2. Build the graph via ``network()``.
      3. Dispatch on role. A server calls ``run_server()``. A worker calls
         ``run_worker()``, then ``fleet.stop_worker()``, then
         ``record_result()``.

    A completion message is logged afterwards in either case.
    """
    self.init_fleet_with_gloo()
    self.network()

    if fleet.is_server():
        self.run_server()
    elif fleet.is_worker():
        # Worker-only lifecycle. The original indentation was lost in this
        # paste; these two teardown calls are read as part of the worker
        # branch because both follow run_worker(). TODO confirm.
        self.run_worker()
        fleet.stop_worker()
        self.record_result()

    logger.info("Run Success, Exit.")
| 91 | |
def init_fleet_with_gloo(self, use_gloo=True):
    """Initialize Paddle fleet, with Gloo enabled by default.

    The original signature lacked ``self``. The call
    ``self.init_fleet_with_gloo()`` therefore bound the instance to
    ``use_gloo``, and any explicit argument raised ``TypeError``. The
    existing no-argument call still takes the Gloo path.

    Args:
        use_gloo: If true (the default), set ``PADDLE_WITH_GLOO=1`` in the
            process environment and pass a
            ``role_maker.PaddleCloudRoleMaker()`` role to ``fleet.init``.
            Otherwise call ``fleet.init()`` without a role.
    """
    if not use_gloo:
        fleet.init()
        return
    # The environment flag is set before the role maker is constructed.
    os.environ["PADDLE_WITH_GLOO"] = "1"
    role = role_maker.PaddleCloudRoleMaker()
    fleet.init(role)
| 99 | |
def network(self) -> None:
    """Build the model, its feed variables, the reader and the optimizer.

    Assigns ``self.model``, ``self.input_data``, ``self.inference_feed_var``,
    ``self.metrics`` (the return value of ``model.net``) and
    ``self.inference_target_var``. Calls ``self.init_reader()``, logs the
    ``CPU_NUM`` environment variable, and finally creates the optimizer
    using the strategy from ``get_strategy(self.config)``.
    """
    self.model = get_model(self.config)
    self.input_data = self.model.create_feeds()
    # NOTE(review): the attribute name suggests inference-time feeds, but
    # is_infer=False is passed. Confirm against the model's create_feeds()
    # whether is_infer=True was intended.
    self.inference_feed_var = self.model.create_feeds(is_infer=False)
    # Runs after the feeds exist and before net(). Presumably the reader
    # is built over self.input_data; TODO confirm in init_reader().
    self.init_reader()
    self.metrics = self.model.net(self.input_data)
    self.inference_target_var = self.model.inference_target_var
    logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
    self.model.create_optimizer(get_strategy(self.config))
| 109 | |
def run_server(self):
    """Initialize this process as a fleet server and start serving.

    Passes the ``runner.warmup_model_path`` config entry to
    ``fleet.init_server``. Per ``.get``, this value may be absent or None
    (assumes dict-like ``get`` semantics). Control is then handed to
    ``fleet.run_server()``.
    """
    logger.info("Run Server Begin")
    # Use the instance's config. The bare name ``config`` previously used
    # here is not a parameter or attribute; it resolves to a module global,
    # if one exists, which may be missing or differ from self.config.
    warmup_model_path = self.config.get("runner.warmup_model_path")
    fleet.init_server(warmup_model_path)
    fleet.run_server()
| 114 | |
| 115 | def run_worker(self): |
| 116 | logger.info("Run Worker Begin") |
| 117 | use_cuda = int(config.get("runner.use_gpu")) |
| 118 | place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() |
| 119 | self.exe = paddle.static.Executor(place) |
| 120 | |
| 121 | with open("./{}_worker_main_program.prototxt".format( |
| 122 | fleet.worker_index()), 'w+') as f: |
| 123 | f.write(str(paddle.static.default_main_program())) |
| 124 | with open("./{}_worker_startup_program.prototxt".format( |
| 125 | fleet.worker_index()), 'w+') as f: |
| 126 | f.write(str(paddle.static.default_startup_program())) |