| 889 | jsonl_file = None |
| 890 | |
| 891 | def open_new_shard(): |
| 892 | nonlocal tar_writer, jsonl_file, shard_idx, shard_sample_count, shard_duration |
| 893 | if tar_writer is not None: |
| 894 | tar_writer.close() |
| 895 | if jsonl_file is not None: |
| 896 | jsonl_file.close() |
| 897 | if shard_idx > 0 and shard_sample_count > 0: |
| 898 | prev_idx = shard_idx - 1 |
| 899 | shard_manifest[prev_idx] = ( |
| 900 | os.path.abspath(tar_output_pattern % prev_idx), |
| 901 | os.path.abspath(jsonl_output_pattern % prev_idx), |
| 902 | shard_sample_count, |
| 903 | shard_duration, |
| 904 | ) |
| 905 | tar_fname = tar_output_pattern % shard_idx |
| 906 | jsonl_fname = jsonl_output_pattern % shard_idx |
| 907 | tar_writer = wds.TarWriter(tar_fname) |
| 908 | jsonl_file = open(jsonl_fname, "w", encoding="utf-8") |
| 909 | shard_idx += 1 |
| 910 | shard_sample_count = 0 |
| 911 | shard_duration = 0.0 |
| 912 | |
| 913 | def write_sample(key, waveform, metadata): |
| 914 | nonlocal shard_sample_count, write_error_count, shard_duration |