MCPcopy Index your code
hub / github.com/jindongwang/transferlearning / recognize_and_evaluate

Function recognize_and_evaluate

code/ASR/Adapter/utils.py:204–297  ·  view source on GitHub ↗
(dataloader, model, args, model_path=None, wer=False, write_to_json=False)

Source from the content-addressed store, hash-verified

202 ) # sclite does not consider the number of spaces when splitting
203 return text
204def recognize_and_evaluate(dataloader, model, args, model_path=None, wer=False, write_to_json=False):
205 if model_path:
206 torch_load(model_path, model)
207 orig_model = model
208 if hasattr(model, "module"):
209 model = model.module
210 if write_to_json:
211 # read json data
212 assert args.result_label and args.recog_json
213 with open(args.recog_json, "rb") as f:
214 js = json.load(f)["utts"]
215 new_js = {}
216 model.eval()
217 recog_args = {
218 "beam_size": args.beam_size,
219 "penalty": args.penalty,
220 "ctc_weight": args.ctc_weight,
221 "maxlenratio": args.maxlenratio,
222 "minlenratio": args.minlenratio,
223 "lm_weight": args.lm_weight,
224 "rnnlm": args.rnnlm,
225 "nbest": args.nbest,
226 "space": args.sym_space,
227 "blank": args.sym_blank,
228 }
229 recog_args = argparse.Namespace(**recog_args)
230
231 #progress_bar = tqdm(dataloader)
232 #progress_bar.set_description("Testing CER/WERs")
233 err_dict = (
234 dict(cer=None)
235 if not wer
236 else dict(cer=collections.defaultdict(int), wer=collections.defaultdict(int))
237 )
238 with torch.no_grad():
239 for batch_idx, data in enumerate(dataloader):
240 logging.warning(f"Testing CER/WERs: {batch_idx+1}/{len(dataloader)}")
241 fbank, ilens, tokens = data
242 fbanks = []
243 for i, fb in enumerate(fbank):
244 fbanks.append(fb[: ilens[i], :])
245 fbank = fbanks
246 nbest_hyps = model.recognize_batch(
247 fbank, recog_args, char_list=None, rnnlm=None
248 )
249 y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
250 if write_to_json:
251 for utt_idx in range(len(fbank)):
252 name = dataloader.dataset[batch_idx][utt_idx][0]
253 new_js[name] = add_results_to_json(
254 js[name], nbest_hyps[utt_idx], args.char_list
255 )
256 for i, y_hat in enumerate(y_hats):
257 y_true = tokens[i]
258
259 hyp_token = [
260 args.char_list[int(idx)] for idx in y_hat if int(idx) != -1
261 ]

Callers 1

train.py (File) · 0.90

Calls 7

token2text (Function) · 0.85
recognize_batch (Method) · 0.80
encode (Method) · 0.80
torch_load (Function) · 0.70
compute_wer (Function) · 0.70
load (Method) · 0.45
write (Method) · 0.45

Tested by

no test coverage detected