JSON dataset, where each line of the data file contains a JSON dictionary with text fields.
| 331 | |
| 332 | |
| 333 | class JsonDataset(object): |
| 334 | """ JSON dataset, where each line of the data file contains a JSON |
| 335 | dictionary with text fields. |
| 336 | """ |
| 337 | |
@staticmethod
def get_default_config(updates=None):
    """Return the default JsonDataset configuration, optionally overridden.

    Args:
        updates: optional mapping (or ConfigDict) of overrides. When given,
            it is wrapped in a ConfigDict, its references are resolved, and
            the result is merged on top of the defaults.

    Returns:
        A ConfigDict containing every default field, with overrides applied.
    """
    # (field name, default value) pairs, applied in this order.
    default_fields = (
        ('path', ''),
        ('seq_length', 1024),
        ('batch_size', 8),
        ('always_start_with_bos', False),
        ('start_seek_loc', 0),
        ('example_index_at_start', 0),
        ('tokens_count_at_start', 0),
        ('tokenizer_processes', 1),
        ('tokenizer_parallel_chunk_size', 32),
        ('tokenizer_parallel_batch_size', 1024),
        ('throughput_average_window_size', 200),
        ('pad', False),
        ('use_data_sharded_loader', True),
        ('return_local_batch', False),
    )
    defaults = ConfigDict()
    for field_name, field_value in default_fields:
        setattr(defaults, field_name, field_value)

    if updates is None:
        return defaults

    resolved_overrides = ConfigDict(updates).copy_and_resolve_references()
    defaults.update(resolved_overrides)
    return defaults
| 359 | |
def __init__(self, config, tokenizer, text_processor, node_info):
    """Set up a reader over a JSON-per-line data file.

    Args:
        config: overrides merged onto ``get_default_config()``; the
            resulting ``path`` must be non-empty.
        tokenizer: stored on the instance for use by other methods.
        text_processor: stored on the instance for use by other methods.
        node_info: stored on the instance for use by other methods.

    Raises:
        ValueError: if the resolved ``config.path`` is empty.
    """
    self.config = self.get_default_config(config)
    # Explicit check instead of ``assert`` so validation still happens
    # when Python runs with -O (which strips assert statements).
    if not self.config.path:
        raise ValueError('JsonDataset requires a non-empty config.path')
    self._tokenizer = tokenizer
    self._text_processor = text_processor
    self._node_info = node_info
    # Read-position state seeded from config; json_iterator starts from
    # self._index and seeks the file to self._file_loc.
    self._index = self.config.example_index_at_start
    self._file_loc = self.config.start_seek_loc
    self._total_tokens = self.config.tokens_count_at_start
| 369 | |
def parse_json(self, line):
    """Decode a single line of the data file as JSON.

    Args:
        line: one line read from the data file, possibly including its
            trailing newline.

    Returns:
        The decoded JSON value, or None when the line is empty,
        whitespace-only, or not valid JSON.
    """
    # Blank separators (including '\r\n' or stray spaces/tabs) are skipped
    # silently instead of being reported as parse errors.
    if not line or line.isspace():
        return None
    try:
        return json.loads(line)
    except json.JSONDecodeError:
        # Report malformed lines and signal failure with None rather than raising.
        print(f'Error parsing json line:\n{line}')
        return None
| 379 | |
| 380 | def json_iterator(self): |
| 381 | index, file_loc = self._index, self._file_loc |
| 382 | with open_file(self.config.path, 'r') as fin: |
| 383 | fin.seek(file_loc) |
| 384 | while True: |
| 385 | line = fin.readline() |
| 386 | file_loc = fin.tell() |
| 387 | if not line: # Reached EOF |
| 388 | index = 0 |
| 389 | fin.seek(0) |
| 390 | continue |