(self)
| 113 | self.config = config |
| 114 | |
def get_reader(self):
    """Build a ``paddle.io.DataLoader`` over the configured ``RecDataset``.

    Values read from ``self.config``:
      * ``config_abs_dir`` -- directory the reader path is joined onto.
      * ``runner.train_reader_path`` -- reader module path, default
        ``'reader'``.
      * ``runner.use_gpu`` -- converted with ``int``; a non-zero value
        selects the ``'gpu'`` device, otherwise ``'cpu'``.
      * ``runner.train_batch_size`` -- forwarded as ``batch_size``
        (``None`` when absent).

    The dataset class is instantiated with ``self.file_list`` and
    ``self.config``, and its ``init()`` is called before it is wrapped.
    Calls ``paddle.set_device`` as part of choosing the place.

    Returns:
        The ``paddle.io.DataLoader`` created with ``drop_last=True``.

    Raises:
        ValueError: if ``config_abs_dir`` is not present in the config.
    """
    logger.info("Get DataLoader")

    config_abs_dir = self.config.get("config_abs_dir", None)
    if config_abs_dir is None:
        # os.path.join(None, ...) would fail with an opaque TypeError;
        # report the actual misconfiguration instead.
        raise ValueError(
            "'config_abs_dir' is missing from the config; cannot resolve "
            "the reader path")
    reader_path = self.config.get('runner.train_reader_path', 'reader')
    reader_path = os.path.join(config_abs_dir, reader_path)
    logger.info("Reader Path: %s", reader_path)

    # Imported locally: this module declares its own class named
    # DataLoader, so the paddle one is bound only inside this method.
    from paddle.io import DataLoader
    dataset = common.lazy_instance_by_fliename(reader_path, "RecDataset")
    # Logged (not printed) so it follows the same channel as the other
    # messages emitted by this method.
    logger.info("dataset: %s", dataset)

    use_cuda = int(self.config.get("runner.use_gpu"))
    batch_size = self.config.get('runner.train_batch_size', None)
    place = paddle.set_device('gpu' if use_cuda else 'cpu')

    generator = dataset(self.file_list, self.config)
    generator.init()
    loader = DataLoader(
        generator, batch_size=batch_size, places=place, drop_last=True)
    return loader
| 136 | |
| 137 | |
| 138 | class DataLoader(object): |
no test coverage detected