MCP copy · Index your code
hub / github.com/jindongwang/transferlearning / load_data

Function load_data

code/ASR/Adapter/data_load.py:60–181  ·  view source on GitHub ↗
(root_path, dataset, args)

Source from the content-addressed store, hash-verified

58
59
60def load_data(root_path, dataset, args):
61 def collate(minibatch):
62 fbanks = []
63 tokens = []
64 for _, info in minibatch[0]:
65 fbanks.append(
66 torch.tensor(
67 kaldiio.load_mat(
68 info["input"][0]["feat"].replace(
69 data_config[dataset]["prefix"], root_path
70 )
71 )
72 )
73 )
74 tokens.append(
75 torch.tensor([int(s) for s in info["output"][0]["tokenid"].split()])
76 )
77 ilens = torch.tensor([x.shape[0] for x in fbanks])
78 return (
79 pad_sequence(fbanks, batch_first=True, padding_value=0),
80 ilens,
81 pad_sequence(tokens, batch_first=True, padding_value=-1),
82 )
83 language = dataset
84 if language in low_resource_languages:
85 template_key = "template100"
86 else:
87 template_key = "template150"
88 data_config[dataset] = data_config[template_key].copy()
89 for key in ["train", "val", "test", "token"]:
90 data_config[dataset][key] = data_config[template_key][key].replace("template", dataset)
91 train_json = os.path.join(root_path, data_config[dataset]["train"])
92 dev_json = (
93 os.path.join(root_path, data_config[dataset]["val"])
94 if data_config[dataset]["val"]
95 else f"{root_path}/tmp_dev_set_{dataset}.json"
96 )
97 test_json = os.path.join(root_path, data_config[dataset]["test"])
98 train_json, dev_json, test_json = load_json(train_json, dev_json, test_json)
99 _, info = next(iter(train_json.items()))
100 idim = info["input"][0]["shape"][1]
101 odim = info["output"][0]["shape"][1]
102
103 use_sortagrad = False # args.sortagrad == -1 or args.sortagrad > 0
104 # trainset = make_batchset(train_json, batch_size, max_length_in=800, max_length_out=150)
105 trainset = make_batchset(
106 train_json,
107 args.batch_size,
108 args.maxlen_in,
109 args.maxlen_out,
110 args.minibatches,
111 min_batch_size=args.ngpu if (args.ngpu > 1 and not args.dist_train) else 1,
112 shortest_first=use_sortagrad,
113 count=args.batch_count,
114 batch_bins=args.batch_bins,
115 batch_frames_in=args.batch_frames_in,
116 batch_frames_out=args.batch_frames_out,
117 batch_frames_inout=args.batch_frames_inout,

Callers

nothing calls this directly

Calls 1

load_json (Function) · 0.70

Tested by

no test coverage detected