(args)
| 44 | |
| 45 | |
| 46 | def main(args): |
| 47 | paddle.seed(12345) |
| 48 | |
| 49 | # load config |
| 50 | config = load_yaml(args.config_yaml) |
| 51 | config["yaml_path"] = args.config_yaml |
| 52 | config["config_abs_dir"] = args.abs_dir |
| 53 | # modify config from command |
| 54 | if args.opt: |
| 55 | for parameter in args.opt: |
| 56 | parameter = parameter.strip() |
| 57 | key, value = parameter.split("=") |
| 58 | if type(config.get(key)) is int: |
| 59 | value = int(value) |
| 60 | if type(config.get(key)) is float: |
| 61 | value = float(value) |
| 62 | if type(config.get(key)) is bool: |
| 63 | value = (True if value.lower() == "true" else False) |
| 64 | config[key] = value |
| 65 | # load static model class |
| 66 | static_model_class = load_static_model_class(config) |
| 67 | input_data = static_model_class.create_feeds() |
| 68 | input_data_names = [data.name for data in input_data] |
| 69 | |
| 70 | fetch_vars = static_model_class.net(input_data) |
| 71 | |
| 72 | #infer_target_var = model.infer_target_var |
| 73 | logger.info("cpu_num: {}".format(os.getenv("CPU_NUM"))) |
| 74 | |
| 75 | use_gpu = config.get("runner.use_gpu", True) |
| 76 | use_xpu = config.get("runner.use_xpu", False) |
| 77 | use_auc = config.get("runner.use_auc", False) |
| 78 | use_visual = config.get("runner.use_visual", False) |
| 79 | use_inference = config.get("runner.use_inference", False) |
| 80 | auc_num = config.get("runner.auc_num", 1) |
| 81 | train_data_dir = config.get("runner.train_data_dir", None) |
| 82 | epochs = config.get("runner.epochs", None) |
| 83 | print_interval = config.get("runner.print_interval", None) |
| 84 | model_save_path = config.get("runner.model_save_path", "model_output") |
| 85 | model_init_path = config.get("runner.model_init_path", None) |
| 86 | batch_size = config.get("runner.train_batch_size", None) |
| 87 | reader_type = config.get("runner.reader_type", "DataLoader") |
| 88 | use_fleet = config.get("runner.use_fleet", False) |
| 89 | use_save_data = config.get("runner.use_save_data", False) |
| 90 | os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1)) |
| 91 | logger.info("**************common.configs**********") |
| 92 | logger.info( |
| 93 | "use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}". |
| 94 | format(use_gpu, use_xpu, use_visual, batch_size, train_data_dir, |
| 95 | epochs, print_interval, model_save_path)) |
| 96 | logger.info("**************common.configs**********") |
| 97 | |
| 98 | if use_xpu: |
| 99 | xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0)) |
| 100 | place = paddle.set_device(xpu_device) |
| 101 | else: |
| 102 | place = paddle.set_device('gpu' if use_gpu else 'cpu') |
| 103 |
no test coverage detected