(uri_path: str = None)
| 32 | |
| 33 | |
| 34 | def train_multiseg(uri_path: str = None): |
| 35 | model = init_instance_by_config(CSI300_GBDT_TASK["model"]) |
| 36 | dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) |
| 37 | with R.start(experiment_name="workflow", uri=uri_path): |
| 38 | R.log_params(**flatten_dict(CSI300_GBDT_TASK)) |
| 39 | model.fit(dataset) |
| 40 | recorder = R.get_recorder() |
| 41 | sr = MultiSegRecord(model, dataset, recorder) |
| 42 | sr.generate(dict(valid="valid", test="test"), True) |
| 43 | uri = R.get_uri() |
| 44 | return uri |
| 45 | |
| 46 | |
| 47 | def train_mse(uri_path: str = None): |
no test coverage detected