| 21 | |
| 22 | |
class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset over a JSON-lines manifest of audio samples.

    Each non-blank line of the manifest is parsed as a JSON object that must
    provide the keys ``'audio'``, ``'source'`` and ``'gt'``. Lines are kept
    as raw text and only parsed when an item is requested.
    """

    def __init__(self, ds, prompt):
        """Load the manifest lines and remember the prompt template.

        Args:
            ds: Mapping with a ``'path'`` entry naming the manifest file.
            prompt: Format string with one positional ``{}`` placeholder; it
                is filled with the sample's ``'audio'`` value.
        """
        path = ds['path']
        # Use a context manager so the handle is closed deterministically,
        # and read as UTF-8 explicitly rather than the platform default.
        with open(path, encoding='utf-8') as manifest:
            # Drop whitespace-only lines (e.g. a trailing newline at EOF):
            # they would inflate __len__ and make json.loads fail later.
            self.datas = [line for line in manifest if line.strip()]
        self.prompt = prompt

    def __len__(self):
        """Return the number of non-blank manifest lines."""
        return len(self.datas)

    def __getitem__(self, idx):
        """Parse manifest line ``idx`` and build the sample dictionary.

        Returns:
            A dict with keys ``'input_text'`` (the prompt formatted with the
            audio path), ``'audio_path'``, ``'source'`` and ``'gt'``.

        Raises:
            json.JSONDecodeError: If the line is not valid JSON.
            KeyError: If ``'audio'``, ``'source'`` or ``'gt'`` is missing.
        """
        data = json.loads(self.datas[idx].strip())
        audio_path = data['audio']
        source = data['source']
        gt = data['gt']  # presumably the ground-truth label; confirm with caller

        return {
            'input_text': self.prompt.format(audio_path),
            'audio_path': audio_path,
            'source': source,
            'gt': gt,
        }
| 45 | |
| 46 | |
| 47 | def collate_fn(inputs, tokenizer): |