(config, place, mode="train")
| 87 | |
| 88 | |
def create_data_loader(config, place, mode="train"):
    """Build a DataLoader over every entry in the configured data directory.

    Args:
        config: mapping-like object queried via ``config.get(key, default)``
            with dotted ``runner.*`` keys and ``config_abs_dir``.
        place: device placement forwarded to ``DataLoader(places=...)``.
        mode: ``"train"`` selects the ``runner.train_*`` settings; any other
            value selects ``runner.test_data_dir`` / ``runner.infer_*``.

    Returns:
        A ``DataLoader`` over ``RecDataset`` from the reader module named by
        the config, with ``drop_last=True``.

    Raises:
        ValueError: if the data-directory key for the selected mode is unset.
    """
    if mode == "train":
        data_dir_key = "runner.train_data_dir"
        batch_size = config.get('runner.train_batch_size', None)
        reader_path = config.get('runner.train_reader_path', 'reader')
    else:
        data_dir_key = "runner.test_data_dir"
        batch_size = config.get('runner.infer_batch_size', None)
        reader_path = config.get('runner.infer_reader_path', 'reader')
    data_dir = config.get(data_dir_key, None)
    if data_dir is None:
        # Fail loudly: os.listdir(None) would otherwise list the current
        # working directory and feed unrelated files to the reader.
        raise ValueError("missing config value: {}".format(data_dir_key))

    num_workers = config.get('runner.num_workers', 0)
    config_abs_dir = config.get("config_abs_dir", None)
    if config_abs_dir is not None:
        # Resolve a relative data_dir against config_abs_dir; os.path.join
        # keeps data_dir unchanged when it is already absolute.
        data_dir = os.path.join(config_abs_dir, data_dir)

    # os.listdir returns entries in arbitrary order; sort so the file order
    # handed to RecDataset is reproducible across runs and filesystems.
    file_list = [os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))]

    logger.info("reader path:%s", reader_path)
    from importlib import import_module
    reader_module = import_module(reader_path)
    dataset = reader_module.RecDataset(file_list, config=config)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        places=place,
        drop_last=True,
        num_workers=num_workers)
| 114 | |
| 115 | |
| 116 | def load_dy_model_class(abs_dir): |
no test coverage detected