Preprocess a single sample for wdsdataloader.
(
sample,
audio_ext,
text_ext,
max_len,
audio_cfg,
class_index_dict=None,
data_filling="pad",
data_truncating="rand_trunc",
text_augment_selection=None,
)
| 564 | |
| 565 | |
| 566 | def preprocess( |
| 567 | sample, |
| 568 | audio_ext, |
| 569 | text_ext, |
| 570 | max_len, |
| 571 | audio_cfg, |
| 572 | class_index_dict=None, |
| 573 | data_filling="pad", |
| 574 | data_truncating="rand_trunc", |
| 575 | text_augment_selection=None, |
| 576 | ): |
| 577 | """ |
| 578 | Preprocess a single sample for wdsdataloader. |
| 579 | """ |
| 580 | audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) |
| 581 | audio_data = int16_to_float32(float32_to_int16(audio_data)) |
| 582 | audio_data = torch.tensor(audio_data).float() |
| 583 | |
| 584 | # TODO: (yusong) to be include in the future |
| 585 | # # if torchaudio not installed, use soundfile to load audio |
| 586 | # if torchaudio is None: |
| 587 | # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) |
| 588 | # audio_data = torch.tensor(audio_data).float() |
| 589 | # else: |
| 590 | # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py |
| 591 | # with tempfile.TemporaryDirectory() as dirname: |
| 592 | # os.makedirs(dirname, exist_ok=True) |
| 593 | # fname = os.path.join(dirname, f"file.flac") |
| 594 | # with open(fname, "wb") as stream: |
| 595 | # stream.write(sample[audio_ext]) |
| 596 | # audio_data, orig_sr = torchaudio.load(fname) |
| 597 | # audio_data = audio_data[0, :].float() |
| 598 | |
| 599 | sample = get_audio_features( |
| 600 | sample, audio_data, max_len, data_truncating, data_filling, audio_cfg |
| 601 | ) |
| 602 | del sample[audio_ext] |
| 603 | |
| 604 | try: |
| 605 | json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) |
| 606 | except: |
| 607 | print("sample[__url__]:", sample["__url__"]) |
| 608 | |
| 609 | # For selecting augmented text from dataset |
| 610 | if text_augment_selection is None or text_augment_selection == "none": |
| 611 | texts = json_dict_raw["text"] |
| 612 | elif text_augment_selection == "all": |
| 613 | if "text_augment_all" in json_dict_raw.keys(): |
| 614 | texts = json_dict_raw["text_augment_all"] |
| 615 | else: |
| 616 | texts = json_dict_raw["text"] |
| 617 | elif text_augment_selection == "augment_only": |
| 618 | if "text_augment_all" in json_dict_raw.keys(): |
| 619 | if json_dict_raw["text_augment_t5"] is None: |
| 620 | texts = json_dict_raw["text"] |
| 621 | else: |
| 622 | texts = json_dict_raw["text_augment_t5"] |
| 623 | else: |
No direct callers of `preprocess` were found.
No test coverage was detected for this function.