| 90 | |
| 91 | |
| 92 | def train(args): # noqa: C901 |
| 93 | if isinstance(args.dataset, list): |
| 94 | args.dataset = args.dataset[0] |
| 95 | if isinstance(args.model, list): |
| 96 | args.model = args.model[0] |
| 97 | if isinstance(args.seed, list): |
| 98 | args.seed = args.seed[0] |
| 99 | if isinstance(args.split, list): |
| 100 | args.split = args.split[0] |
| 101 | set_random_seed(args.seed) |
| 102 | |
| 103 | model_name = args.model if isinstance(args.model, str) else args.model.model_name |
| 104 | dw_name = args.dw if isinstance(args.dw, str) else args.dw.__name__ |
| 105 | mw_name = args.mw if isinstance(args.mw, str) else args.mw.__name__ |
| 106 | |
| 107 | print( |
| 108 | f""" |
| 109 | |-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}| |
| 110 | *** Running (`{args.dataset}`, `{model_name}`, `{dw_name}`, `{mw_name}`) |
| 111 | |-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|""" |
| 112 | ) |
| 113 | |
| 114 | if hasattr(args, "save_model_path"): |
| 115 | args = build_model_path(args, model_name) |
| 116 | |
| 117 | if getattr(args, "use_best_config", False): |
| 118 | args = set_best_config(args) |
| 119 | |
| 120 | # setup dataset and specify `num_features` and `num_classes` for model |
| 121 | if isinstance(args.dataset, Dataset): |
| 122 | dataset = args.dataset |
| 123 | else: |
| 124 | dataset = build_dataset(args) |
| 125 | |
| 126 | mw_class = fetch_model_wrapper(args.mw) |
| 127 | dw_class = fetch_data_wrapper(args.dw) |
| 128 | |
| 129 | if mw_class is None: |
| 130 | raise NotImplementedError("`model wrapper(--mw)` must be specified.") |
| 131 | |
| 132 | if dw_class is None: |
| 133 | raise NotImplementedError("`data wrapper(--dw)` must be specified.") |
| 134 | |
| 135 | data_wrapper_args = dict() |
| 136 | model_wrapper_args = dict() |
| 137 | # unworthy code: share `args` between model and dataset_wrapper |
| 138 | for key in inspect.signature(dw_class).parameters.keys(): |
| 139 | if hasattr(args, key) and key != "dataset": |
| 140 | data_wrapper_args[key] = getattr(args, key) |
| 141 | # unworthy code: share `args` between model and model_wrapper |
| 142 | for key in inspect.signature(mw_class).parameters.keys(): |
| 143 | if hasattr(args, key) and key != "model": |
| 144 | model_wrapper_args[key] = getattr(args, key) |
| 145 | |
| 146 | # setup data_wrapper |
| 147 | dataset_wrapper = dw_class(dataset, **data_wrapper_args) |
| 148 | |
| 149 | args.num_features = dataset.num_features |