(args)
| 39 | |
| 40 | |
| 41 | def train(args): # noqa: C901 |
| 42 | if isinstance(args.dataset, list): |
| 43 | args.dataset = args.dataset[0] |
| 44 | if isinstance(args.model, list): |
| 45 | args.model = args.model[0] |
| 46 | if isinstance(args.seed, list): |
| 47 | args.seed = args.seed[0] |
| 48 | if isinstance(args.split, list): |
| 49 | args.split = args.split[0] |
| 50 | # dataset='cora', model='gcn', seed=1, split=0 |
| 51 | set_random_seed(args.seed) |
| 52 | |
| 53 | |
| 54 | model_name = args.model if isinstance(args.model, str) else args.model.model_name |
| 55 | dw_name = args.dw if isinstance(args.dw, str) else args.dw.__name__ |
| 56 | mw_name = args.mw if isinstance(args.mw, str) else args.mw.__name__ |
| 57 | |
| 58 | print( |
| 59 | f""" |
| 60 | |-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}| |
| 61 | *** Running (`{args.dataset}`, `{model_name}`, `{dw_name}`, `{mw_name}`) |
| 62 | |-------------------------------------{'-' * ( |
| 63 | len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|""" |
| 64 | ) |
| 65 | |
| 66 | |
| 67 | dataset = build_dataset(args) |
| 68 | mw_class = fetch_model_wrapper(args.mw) |
| 69 | dw_class = fetch_data_wrapper(args.dw) |
| 70 | |
| 71 | if mw_class is None: |
| 72 | raise NotImplementedError("`model wrapper(--mw)` must be specified.") |
| 73 | |
| 74 | if dw_class is None: |
| 75 | raise NotImplementedError("`data wrapper(--dw)` must be specified.") |
| 76 | |
| 77 | data_wrapper_args = dict() |
| 78 | model_wrapper_args = dict() |
| 79 | |
| 80 | data_wrapper_args['batch_size'] = args.batch_size |
| 81 | data_wrapper_args['n_his'] = args.n_his |
| 82 | data_wrapper_args['n_pred'] = args.n_pred |
| 83 | data_wrapper_args['train_prop'] = args.train_prop |
| 84 | data_wrapper_args['val_prop'] = args.val_prop |
| 85 | data_wrapper_args['test_prop'] = args.test_prop |
| 86 | data_wrapper_args['pred_length'] = args.pred_length |
| 87 | |
| 88 | dataset_wrapper = dw_class(dataset, **data_wrapper_args) |
| 89 | |
| 90 | args.num_features = dataset.num_features |
| 91 | if hasattr(dataset, "num_nodes"): |
| 92 | args.num_nodes = dataset.num_nodes |
| 93 | if hasattr(dataset, "num_edges"): |
| 94 | args.num_edges = dataset.num_edges |
| 95 | if hasattr(dataset, "num_edge"): |
| 96 | args.num_edge = dataset.num_edge |
| 97 | if hasattr(dataset, "max_graph_size"): |
| 98 | args.max_graph_size = dataset.max_graph_size |
no test coverage detected