| 86 | |
| 87 | |
def create_data_loader(config, place, mode="train"):
    """Build a ``DataLoader`` over every entry in the configured data directory.

    The dataset class is resolved dynamically: ``reader_path`` names an
    importable module that must expose ``RecDataset(file_list, config=...)``.

    Args:
        config: Object supporting ``.get(key, default)``. Keys read depend on
            ``mode``:

            - ``"train"``: ``runner.train_data_dir``,
              ``runner.train_batch_size``, ``runner.train_reader_path``
              (default ``'reader'``);
            - any other value: ``runner.test_data_dir``,
              ``runner.infer_batch_size``, ``runner.infer_reader_path``
              (default ``'reader'``).

            ``config_abs_dir`` (when present) is joined in front of the data
            dir; presumably it is the config file's directory -- confirm.
        place: Forwarded unchanged to ``DataLoader(places=...)``.
        mode: ``"train"`` selects the training keys; anything else selects
            the test/infer keys.

    Returns:
        A ``DataLoader`` over ``RecDataset`` built with ``drop_last=True``.

    Raises:
        ValueError: if the data directory for ``mode`` is not configured.
    """
    from importlib import import_module

    if mode == "train":
        data_dir = config.get("runner.train_data_dir", None)
        batch_size = config.get('runner.train_batch_size', None)
        reader_path = config.get('runner.train_reader_path', 'reader')
    else:
        data_dir = config.get("runner.test_data_dir", None)
        batch_size = config.get('runner.infer_batch_size', None)
        reader_path = config.get('runner.infer_reader_path', 'reader')

    if data_dir is None:
        # Fail with a clear message instead of an opaque TypeError from join.
        raise ValueError(
            "data directory is not configured for mode={!r}".format(mode))

    config_abs_dir = config.get("config_abs_dir", None)
    if config_abs_dir is not None:
        # os.path.join leaves data_dir untouched if it is already absolute.
        data_dir = os.path.join(config_abs_dir, data_dir)

    # os.listdir returns entries in arbitrary order; sort so the file order
    # (and hence the data order fed to the reader) is reproducible.
    file_list = [
        os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))
    ]

    logger.info("reader path:%s", reader_path)
    reader_module = import_module(reader_path)
    dataset = reader_module.RecDataset(file_list, config=config)
    return DataLoader(
        dataset, batch_size=batch_size, places=place, drop_last=True)
| 108 | |
| 109 | |
| 110 | def load_dy_model_class(abs_dir): |