(train_json_file, dev_json_file, test_json_file)
| 35 | contents = json.load(f)["utts"] |
| 36 | return contents |
def load_json(train_json_file, dev_json_file, test_json_file):
    """Load the train, dev and test sets via ``read_json_file``.

    If ``dev_json_file`` does not exist, the last 10% of the training
    entries (in file order) are split off as the dev set. That split is
    written to ``dev_json_file`` as ``{"utts": <dev entries>}`` so later
    runs read it back.

    Any training entry whose key also appears in the dev set is removed
    from the training set. Without this, a rerun after the temporary dev
    file has been written would load the full training file again and
    leak the dev entries back into training.

    Args:
        train_json_file: Path to the training JSON file.
        dev_json_file: Path to the dev JSON file; created if missing.
        test_json_file: Path to the test JSON file.

    Returns:
        Tuple ``(train_json, dev_json, test_json)`` of dicts.
    """
    train_json = read_json_file(train_json_file)
    if os.path.isfile(dev_json_file):
        dev_json = read_json_file(dev_json_file)
        # A dev file left by a previous auto-split (or any dev set sharing
        # keys with train) must not overlap the training data.
        overlap = [key for key in train_json if key in dev_json]
        if overlap:
            logging.warning(
                f"Removing {len(overlap)} samples from training data that also appear in the dev set"
            )
            train_json = {k: v for k, v in train_json.items() if k not in dev_json}
    else:
        n_samples = len(train_json)
        train_size = int(0.9 * n_samples)
        logging.warning(
            f"No dev set provided, will split the last {n_samples - train_size} (10%) samples from training data"
        )
        # The split is deliberately not shuffled. Keeping file order makes it
        # deterministic, so on a rerun the overlap filter above excludes
        # exactly the samples that were saved as dev here.
        train_json_item = list(train_json.items())
        train_json = dict(train_json_item[:train_size])
        dev_json = dict(train_json_item[train_size:])

        # Persist the split so subsequent runs reuse the same dev set.
        with open(dev_json_file, "w") as f:
            json.dump({"utts": dev_json}, f)
        logging.warning(f"Temporary dev set saved: {dev_json_file}")
    test_json = read_json_file(test_json_file)
    return train_json, dev_json, test_json
| 58 | |
| 59 | |
| 60 | def load_data(root_path, dataset, args): |
no test coverage detected