recognize_and_evaluate(dataloader, model, args, model_path=None, wer=False, write_to_json=False)
| 202 | ) # sclite does not consider the number of spaces when splitting |
| 203 | return text |
| 204 | def recognize_and_evaluate(dataloader, model, args, model_path=None, wer=False, write_to_json=False): |
| 205 | if model_path: |
| 206 | torch_load(model_path, model) |
| 207 | orig_model = model |
| 208 | if hasattr(model, "module"): |
| 209 | model = model.module |
| 210 | if write_to_json: |
| 211 | # read json data |
| 212 | assert args.result_label and args.recog_json |
| 213 | with open(args.recog_json, "rb") as f: |
| 214 | js = json.load(f)["utts"] |
| 215 | new_js = {} |
| 216 | model.eval() |
| 217 | recog_args = { |
| 218 | "beam_size": args.beam_size, |
| 219 | "penalty": args.penalty, |
| 220 | "ctc_weight": args.ctc_weight, |
| 221 | "maxlenratio": args.maxlenratio, |
| 222 | "minlenratio": args.minlenratio, |
| 223 | "lm_weight": args.lm_weight, |
| 224 | "rnnlm": args.rnnlm, |
| 225 | "nbest": args.nbest, |
| 226 | "space": args.sym_space, |
| 227 | "blank": args.sym_blank, |
| 228 | } |
| 229 | recog_args = argparse.Namespace(**recog_args) |
| 230 | |
| 231 | #progress_bar = tqdm(dataloader) |
| 232 | #progress_bar.set_description("Testing CER/WERs") |
| 233 | err_dict = ( |
| 234 | dict(cer=None) |
| 235 | if not wer |
| 236 | else dict(cer=collections.defaultdict(int), wer=collections.defaultdict(int)) |
| 237 | ) |
| 238 | with torch.no_grad(): |
| 239 | for batch_idx, data in enumerate(dataloader): |
| 240 | logging.warning(f"Testing CER/WERs: {batch_idx+1}/{len(dataloader)}") |
| 241 | fbank, ilens, tokens = data |
| 242 | fbanks = [] |
| 243 | for i, fb in enumerate(fbank): |
| 244 | fbanks.append(fb[: ilens[i], :]) |
| 245 | fbank = fbanks |
| 246 | nbest_hyps = model.recognize_batch( |
| 247 | fbank, recog_args, char_list=None, rnnlm=None |
| 248 | ) |
| 249 | y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps] |
| 250 | if write_to_json: |
| 251 | for utt_idx in range(len(fbank)): |
| 252 | name = dataloader.dataset[batch_idx][utt_idx][0] |
| 253 | new_js[name] = add_results_to_json( |
| 254 | js[name], nbest_hyps[utt_idx], args.char_list |
| 255 | ) |
| 256 | for i, y_hat in enumerate(y_hats): |
| 257 | y_true = tokens[i] |
| 258 | |
| 259 | hyp_token = [ |
| 260 | args.char_list[int(idx)] for idx in y_hat if int(idx) != -1 |
| 261 | ] |
No test coverage detected.