| 57 | |
| 58 | |
def test_train_val_split(model_data: RasaModelData):
    """Check that ``model_data.split(2, 42)`` yields two consistent splits.

    Verifies that both splits keep every key / sub-key of the original data
    with the same number of feature entries, that the train split keeps the
    dtype of the original features, and that every array has a leading
    dimension of 3 in the train split and 2 in the test split.
    """
    # NOTE(review): presumably 2 is the number of test examples and 42 the
    # random seed -- confirm against RasaModelData.split.
    train_model_data, test_model_data = model_data.split(2, 42)

    for key, attribute_features in model_data.items():
        # Both splits must expose the same sub-keys count as the original.
        assert len(attribute_features) == len(train_model_data.get(key))
        assert len(attribute_features) == len(test_model_data.get(key))

        for sub_key, features in attribute_features.items():
            train_features = train_model_data.get(key, sub_key)
            assert len(features) == len(train_features)
            assert len(features) == len(test_model_data.get(key, sub_key))

            # Lengths were asserted equal above, so pairing element-wise
            # visits exactly the same (original, train) pairs as indexing.
            for original, from_train in zip(features, train_features):
                if isinstance(original[0], list):
                    # Nested list of arrays: compare the innermost dtype.
                    expected_dtype = original[0][0].dtype
                    actual_dtype = from_train[0][0].dtype
                else:
                    expected_dtype = original[0].dtype
                    actual_dtype = from_train[0].dtype
                assert expected_dtype == actual_dtype

    # Train split is checked first, then the test split.
    expected_sizes = ((train_model_data, 3), (test_model_data, 2))
    for split_data, expected_size in expected_sizes:
        for attribute_features in split_data.values():
            for features in attribute_features.values():
                for feature in features:
                    assert np.array(feature).shape[0] == expected_size
| 86 | |
| 87 | |
| 88 | @pytest.mark.parametrize("size", [0, 1, 5]) |