(root_path, dataset, args)
| 58 | |
| 59 | |
| 60 | def load_data(root_path, dataset, args): |
def collate(minibatch):
    """Build padded feature and token tensors from one pre-batched minibatch.

    Only ``minibatch[0]`` is consumed; it is iterated as ``(key, info)``
    pairs and the key is ignored. For every pair:

    * the path stored in ``info["input"][0]["feat"]`` has the
      ``data_config[dataset]["prefix"]`` substring replaced by ``root_path``
      and is then read with ``kaldiio.load_mat``;
    * the whitespace-separated entries of ``info["output"][0]["tokenid"]``
      are converted to integers.

    ``data_config``, ``dataset`` and ``root_path`` are not parameters; they
    are looked up from the surrounding scope each time an item is processed.

    Returns:
        A 3-tuple ``(features, lengths, targets)`` where ``features`` is the
        loaded matrices padded with 0, ``lengths`` holds ``shape[0]`` of each
        unpadded matrix, and ``targets`` is the token id sequences padded
        with -1. Both padded tensors are batch-first.
    """
    feature_mats = []
    target_seqs = []

    for _, entry in minibatch[0]:
        # Re-root the stored feature path before reading the matrix.
        stored_path = entry["input"][0]["feat"]
        local_path = stored_path.replace(data_config[dataset]["prefix"], root_path)
        feature_mats.append(torch.tensor(kaldiio.load_mat(local_path)))

        token_ids = list(map(int, entry["output"][0]["tokenid"].split()))
        target_seqs.append(torch.tensor(token_ids))

    # Record the unpadded first-dimension sizes before padding hides them.
    lengths = torch.tensor([mat.shape[0] for mat in feature_mats])
    padded_features = pad_sequence(feature_mats, batch_first=True, padding_value=0)
    padded_targets = pad_sequence(target_seqs, batch_first=True, padding_value=-1)
    return padded_features, lengths, padded_targets
| 83 | language = dataset |
| 84 | if language in low_resource_languages: |
| 85 | template_key = "template100" |
| 86 | else: |
| 87 | template_key = "template150" |
| 88 | data_config[dataset] = data_config[template_key].copy() |
| 89 | for key in ["train", "val", "test", "token"]: |
| 90 | data_config[dataset][key] = data_config[template_key][key].replace("template", dataset) |
| 91 | train_json = os.path.join(root_path, data_config[dataset]["train"]) |
| 92 | dev_json = ( |
| 93 | os.path.join(root_path, data_config[dataset]["val"]) |
| 94 | if data_config[dataset]["val"] |
| 95 | else f"{root_path}/tmp_dev_set_{dataset}.json" |
| 96 | ) |
| 97 | test_json = os.path.join(root_path, data_config[dataset]["test"]) |
| 98 | train_json, dev_json, test_json = load_json(train_json, dev_json, test_json) |
| 99 | _, info = next(iter(train_json.items())) |
| 100 | idim = info["input"][0]["shape"][1] |
| 101 | odim = info["output"][0]["shape"][1] |
| 102 | |
| 103 | use_sortagrad = False # args.sortagrad == -1 or args.sortagrad > 0 |
| 104 | # trainset = make_batchset(train_json, batch_size, max_length_in=800, max_length_out=150) |
| 105 | trainset = make_batchset( |
| 106 | train_json, |
| 107 | args.batch_size, |
| 108 | args.maxlen_in, |
| 109 | args.maxlen_out, |
| 110 | args.minibatches, |
| 111 | min_batch_size=args.ngpu if (args.ngpu > 1 and not args.dist_train) else 1, |
| 112 | shortest_first=use_sortagrad, |
| 113 | count=args.batch_count, |
| 114 | batch_bins=args.batch_bins, |
| 115 | batch_frames_in=args.batch_frames_in, |
| 116 | batch_frames_out=args.batch_frames_out, |
| 117 | batch_frames_inout=args.batch_frames_inout, |
nothing calls this directly
no test coverage detected