(root_path, datasets, args, languages)
| 182 | |
| 183 | |
| 184 | def load_multilingual_data(root_path, datasets, args, languages): |
def collate(minibatch):
    """Turn pre-batched utterance lists into padded tensors.

    Each element of ``minibatch`` is an iterable of ``(key, info)`` pairs.
    ``info["input"][0]["feat"]`` is a feature path readable by
    ``kaldiio.load_mat``, ``info["output"][0]["tokenid"]`` is a
    whitespace-separated string of integer token ids, and
    ``info["category"]`` is the language / dataset name attached by the
    enclosing loader.

    Returns:
        One ``(features, ilens, tokens, language)`` tuple per element of
        ``minibatch``: features padded with 0, ``ilens`` holding each
        utterance's first-dimension length, tokens padded with -1, and the
        language shared by the whole batch. If ``minibatch`` has exactly one
        element the tuple itself is returned instead of a one-item list.

    Raises:
        AssertionError: if a single batch mixes utterances of different
            languages.
        KeyError: if a sample's category has no entry in ``data_config``.
    """
    out = []
    for b in minibatch:
        fbanks = []
        tokens = []
        language = None
        for _, info in b:
            category = info['category']
            # Check language consistency before doing any file I/O.
            if language is None:
                language = category
            else:
                assert language == category, (
                    f"mixed languages in one batch: {language} vs {category}"
                )
            # Use the sample's own category for the prefix lookup. Reading the
            # enclosing loop variable `dataset` here would late-bind to the
            # LAST dataset processed by the loader, because this function only
            # runs after that loop has finished.
            feat_path = info["input"][0]["feat"].replace(
                data_config[category]["prefix"], root_path
            )
            fbanks.append(torch.tensor(kaldiio.load_mat(feat_path)))
            tokens.append(
                torch.tensor([int(s) for s in info["output"][0]["tokenid"].split()])
            )
        ilens = torch.tensor([x.shape[0] for x in fbanks])
        out.append((
            pad_sequence(fbanks, batch_first=True, padding_value=0),
            ilens,
            pad_sequence(tokens, batch_first=True, padding_value=-1),
            language,
        ))
    return out[0] if len(out) == 1 else out
| 216 | idim = None |
| 217 | odim_dict = {} |
| 218 | mtl_train_json, mtl_dev_json, mtl_test_json = {}, {}, {} |
| 219 | for idx, dataset in enumerate(datasets): |
| 220 | language = dataset |
| 221 | if language in low_resource_languages: |
| 222 | template_key = "template100" |
| 223 | else: |
| 224 | template_key = "template150" |
| 225 | data_config[dataset] = data_config[template_key].copy() |
| 226 | for key in ["train", "val", "test", "token"]: |
| 227 | data_config[dataset][key] = data_config[template_key][key].replace("template", dataset) |
| 228 | |
| 229 | train_json = os.path.join(root_path, data_config[dataset]["train"]) |
| 230 | dev_json = ( |
| 231 | os.path.join(root_path, data_config[dataset]["val"]) |
| 232 | if data_config[dataset]["val"] |
| 233 | else f"{root_path}/tmp_dev_set_{dataset}.json" |
| 234 | ) |
| 235 | test_json = os.path.join(root_path, data_config[dataset]["test"]) |
| 236 | train_json, dev_json, test_json = load_json(train_json, dev_json, test_json) |
| 237 | for key in train_json.keys(): |
| 238 | train_json[key]['category'] = language |
| 239 | for key in dev_json.keys(): |
| 240 | dev_json[key]['category'] = language |
| 241 | for key in test_json.keys(): |
nothing calls this directly
no test coverage detected