| 660 | jsonl_file = None |
| 661 | |
| 662 | def open_new_shard(): |
| 663 | nonlocal tar_writer, jsonl_file, shard_idx, shard_sample_count, shard_duration |
| 664 | if tar_writer is not None: |
| 665 | tar_writer.close() |
| 666 | if jsonl_file is not None: |
| 667 | jsonl_file.close() |
| 668 | # Record manifest for the previous shard |
| 669 | if shard_idx > 0 and shard_sample_count > 0: |
| 670 | prev_idx = shard_idx - 1 |
| 671 | shard_manifest[prev_idx] = ( |
| 672 | os.path.abspath(tar_output_pattern % prev_idx), |
| 673 | os.path.abspath(jsonl_output_pattern % prev_idx), |
| 674 | shard_sample_count, |
| 675 | shard_duration, |
| 676 | ) |
| 677 | tar_fname = tar_output_pattern % shard_idx |
| 678 | jsonl_fname = jsonl_output_pattern % shard_idx |
| 679 | tar_writer = wds.TarWriter(tar_fname) |
| 680 | jsonl_file = open(jsonl_fname, "w", encoding="utf-8") |
| 681 | shard_idx += 1 |
| 682 | shard_sample_count = 0 |
| 683 | shard_duration = 0.0 |
| 684 | |
| 685 | def write_sample(key, audio_tokens_np, metadata): |
| 686 | nonlocal shard_sample_count, write_error_count, shard_duration |