MCPcopy
hub / github.com/jindongwang/transferlearning / load_multilingual_data

Function load_multilingual_data

code/ASR/Adapter/data_load.py:184–346  ·  view source on GitHub ↗
(root_path, datasets, args, languages)

Source from the content-addressed store, hash-verified

182
183
184def load_multilingual_data(root_path, datasets, args, languages):
185 def collate(minibatch):
186 out = []
187 for b in minibatch:
188 fbanks = []
189 tokens = []
190 language = None
191 for _, info in b:
192 fbanks.append(
193 torch.tensor(
194 kaldiio.load_mat(
195 info["input"][0]["feat"].replace(
196 data_config[dataset]["prefix"], root_path
197 )
198 )
199 )
200 )
201 tokens.append(
202 torch.tensor([int(s) for s in info["output"][0]["tokenid"].split()])
203 )
204 if language is not None:
205 assert language == info['category']
206 else:
207 language = info['category']
208 ilens = torch.tensor([x.shape[0] for x in fbanks])
209 out.append((
210 pad_sequence(fbanks, batch_first=True, padding_value=0),
211 ilens,
212 pad_sequence(tokens, batch_first=True, padding_value=-1),
213 language,
214 ))
215 return out[0] if len(out) == 1 else out
216 idim = None
217 odim_dict = {}
218 mtl_train_json, mtl_dev_json, mtl_test_json = {}, {}, {}
219 for idx, dataset in enumerate(datasets):
220 language = dataset
221 if language in low_resource_languages:
222 template_key = "template100"
223 else:
224 template_key = "template150"
225 data_config[dataset] = data_config[template_key].copy()
226 for key in ["train", "val", "test", "token"]:
227 data_config[dataset][key] = data_config[template_key][key].replace("template", dataset)
228
229 train_json = os.path.join(root_path, data_config[dataset]["train"])
230 dev_json = (
231 os.path.join(root_path, data_config[dataset]["val"])
232 if data_config[dataset]["val"]
233 else f"{root_path}/tmp_dev_set_{dataset}.json"
234 )
235 test_json = os.path.join(root_path, data_config[dataset]["test"])
236 train_json, dev_json, test_json = load_json(train_json, dev_json, test_json)
237 for key in train_json.keys():
238 train_json[key]['category'] = language
239 for key in dev_json.keys():
240 dev_json[key]['category'] = language
241 for key in test_json.keys():

Callers

nothing calls this directly

Calls 3

load_json · Function · 0.70
update · Method · 0.45

Tested by

no test coverage detected