Repository: github.com/THUDM/CogDL

Function train(args)

cogdl/experiments.py:92–230

def train(args):  # noqa: C901
    if isinstance(args.dataset, list):
        args.dataset = args.dataset[0]
    if isinstance(args.model, list):
        args.model = args.model[0]
    if isinstance(args.seed, list):
        args.seed = args.seed[0]
    if isinstance(args.split, list):
        args.split = args.split[0]
    set_random_seed(args.seed)

    model_name = args.model if isinstance(args.model, str) else args.model.model_name
    dw_name = args.dw if isinstance(args.dw, str) else args.dw.__name__
    mw_name = args.mw if isinstance(args.mw, str) else args.mw.__name__

    print(
        f"""
|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|
    *** Running (`{args.dataset}`, `{model_name}`, `{dw_name}`, `{mw_name}`)
|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|"""
    )

    if hasattr(args, "save_model_path"):
        args = build_model_path(args, model_name)

    if getattr(args, "use_best_config", False):
        args = set_best_config(args)

    # setup dataset and specify `num_features` and `num_classes` for model
    if isinstance(args.dataset, Dataset):
        dataset = args.dataset
    else:
        dataset = build_dataset(args)

    mw_class = fetch_model_wrapper(args.mw)
    dw_class = fetch_data_wrapper(args.dw)

    if mw_class is None:
        raise NotImplementedError("`model wrapper(--mw)` must be specified.")

    if dw_class is None:
        raise NotImplementedError("`data wrapper(--dw)` must be specified.")

    data_wrapper_args = dict()
    model_wrapper_args = dict()
    # unworthy code: share `args` between model and dataset_wrapper
    for key in inspect.signature(dw_class).parameters.keys():
        if hasattr(args, key) and key != "dataset":
            data_wrapper_args[key] = getattr(args, key)
    # unworthy code: share `args` between model and model_wrapper
    for key in inspect.signature(mw_class).parameters.keys():
        if hasattr(args, key) and key != "model":
            model_wrapper_args[key] = getattr(args, key)

    # setup data_wrapper
    dataset_wrapper = dw_class(dataset, **data_wrapper_args)

    args.num_features = dataset.num_features
    # ... (remainder of train(), through line 230, not shown)
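The first lines of train() accept dataset, model, seed, and split either as a single value or as a list, and keep only the first element of a list. A minimal stdlib sketch of the same normalization on an argparse.Namespace; unwrap_first is a hypothetical helper, not part of CogDL:

```python
from argparse import Namespace


def unwrap_first(args, keys=("dataset", "model", "seed", "split")):
    """Replace list-valued attributes named in `keys` with their first element.

    Hypothetical helper mirroring the isinstance(..., list) checks at the top of
    train(). Unlike train(), a missing attribute is tolerated (getattr default).
    """
    for key in keys:
        value = getattr(args, key, None)
        if isinstance(value, list):
            setattr(args, key, value[0])
    return args


args = Namespace(dataset=["cora", "citeseer"], model="gcn", seed=[1, 2, 3], split=0)
unwrap_first(args)
print(args.dataset, args.model, args.seed, args.split)  # cora gcn 1 0
```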
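train() then calls set_random_seed(args.seed), whose body is not part of this excerpt. As an illustration only, the stdlib sketch below seeds Python's random module and checks that a fixed seed reproduces the same draws. set_seed_sketch and sample_run are hypothetical names; a seeding helper for a PyTorch-based library would typically also seed NumPy and torch, which this sketch deliberately omits.

```python
import random


def set_seed_sketch(seed: int) -> None:
    """Hypothetical, stdlib-only seeding helper.

    Only Python's `random` module is seeded here; NumPy and torch are not.
    """
    random.seed(seed)


def sample_run(seed: int) -> list:
    # Seed first, then draw: the same seed must yield the same draws.
    set_seed_sketch(seed)
    return [random.random() for _ in range(3)]


assert sample_run(42) == sample_run(42)
```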
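Several inputs to train() may be either a name or an object: args.dw and args.mw are a string or a class (reported via `__name__`), args.model is a string or an object with a model_name attribute, and args.dataset is used as-is when it is already a Dataset, otherwise build_dataset(args) is called. A self-contained sketch of this dual-input pattern; Dataset, MyDataWrapper, wrapper_name, and resolve_dataset below are hypothetical stand-ins, not CogDL types:

```python
class Dataset:
    """Hypothetical stand-in for a dataset base class (not cogdl's Dataset)."""

    def __init__(self, name):
        self.name = name


class MyDataWrapper:
    """Hypothetical data-wrapper class passed in directly instead of by name."""


def wrapper_name(dw):
    # A string is already a registry key; a class is reported by its __name__.
    return dw if isinstance(dw, str) else dw.__name__


def resolve_dataset(dataset, build):
    # Use a ready-made Dataset as-is; otherwise hand the name to a builder.
    return dataset if isinstance(dataset, Dataset) else build(dataset)


print(wrapper_name("example_dw"))  # example_dw
print(wrapper_name(MyDataWrapper))  # MyDataWrapper
print(resolve_dataset("cora", build=Dataset).name)  # cora
```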
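fetch_model_wrapper and fetch_data_wrapper resolve args.mw and args.dw to classes, and train() raises NotImplementedError if either lookup yields None. How CogDL performs the lookup is not shown in this excerpt, so the sketch below assumes a simple dict-backed registry that passes classes through unchanged and returns None for unknown names. The registries, wrapper classes, fetch_wrapper, and resolve_wrappers are all hypothetical; only the two error messages are taken from the source.

```python
from typing import Optional, Union


class ExampleModelWrapper:
    """Hypothetical model-wrapper class."""


class ExampleDataWrapper:
    """Hypothetical data-wrapper class."""


# Hypothetical name -> class registries.
MODEL_WRAPPERS = {"example_mw": ExampleModelWrapper}
DATA_WRAPPERS = {"example_dw": ExampleDataWrapper}


def fetch_wrapper(spec: Union[str, type, None], registry: dict) -> Optional[type]:
    """Return the wrapper class for `spec`, or None if it cannot be resolved."""
    if isinstance(spec, type):
        return spec  # already a class: use it unchanged
    return registry.get(spec)  # unknown names (and None) yield None


def resolve_wrappers(mw, dw):
    mw_class = fetch_wrapper(mw, MODEL_WRAPPERS)
    dw_class = fetch_wrapper(dw, DATA_WRAPPERS)
    # Same checks and messages as train(): model wrapper first, then data wrapper.
    if mw_class is None:
        raise NotImplementedError("`model wrapper(--mw)` must be specified.")
    if dw_class is None:
        raise NotImplementedError("`data wrapper(--dw)` must be specified.")
    return mw_class, dw_class


try:
    resolve_wrappers("example_mw", "missing_dw")
except NotImplementedError as err:
    print(err)  # `data wrapper(--dw)` must be specified.
```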
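The two loops marked "unworthy code" pass configuration to the wrappers by introspection: for each parameter name in the wrapper constructor's signature, the attribute of args with the same name (if present) becomes a keyword argument. The data wrapper skips "dataset", which is passed positionally in dw_class(dataset, **data_wrapper_args); the model wrapper likewise skips "model". A standalone sketch of the technique with inspect.signature; kwargs_from_args and ToyDataWrapper are hypothetical:

```python
import inspect
from argparse import Namespace


def kwargs_from_args(cls, args, exclude):
    """Collect attributes of `args` whose names match parameters of cls.__init__.

    Names in `exclude` are skipped because the caller supplies them explicitly.
    """
    kwargs = {}
    for key in inspect.signature(cls).parameters:
        if hasattr(args, key) and key not in exclude:
            kwargs[key] = getattr(args, key)
    return kwargs


class ToyDataWrapper:
    # Hypothetical wrapper: batch_size and num_workers are configurable.
    def __init__(self, dataset, batch_size=32, num_workers=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers


args = Namespace(dataset="cora", batch_size=128, lr=0.01, epochs=200)
dw_kwargs = kwargs_from_args(ToyDataWrapper, args, exclude={"dataset"})
print(dw_kwargs)  # {'batch_size': 128}
wrapper = ToyDataWrapper("cora", **dw_kwargs)
print(wrapper.batch_size, wrapper.num_workers)  # 128 0
```

A consequence of this design is that any attribute on args whose name happens to match a constructor parameter is forwarded, while a misspelled attribute is silently ignored and the parameter keeps its default.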

Callers: 15 (11 listed)

test_train (Function) · 0.90
test_gin_mutag (Function) · 0.90
test_gin_imdb_binary (Function) · 0.90
test_gin_proteins (Function) · 0.90
test_sortpool_mutag (Function) · 0.90
test_patchy_san_mutag (Function) · 0.90
test_kmeans_cora (Function) · 0.90
test_spectral_cora (Function) · 0.90
test_prone_cora (Function) · 0.90
test_agc_cora (Function) · 0.90
test_daegc_cora (Function) · 0.90

Calls: 10

run (Method) · 0.95
set_random_seed (Function) · 0.90
build_model_path (Function) · 0.90
build_dataset (Function) · 0.90
fetch_model_wrapper (Function) · 0.90
fetch_data_wrapper (Function) · 0.90
build_model (Function) · 0.90
Trainer (Class) · 0.90
set_best_config (Function) · 0.85
keys (Method) · 0.45

Tested by: 15 (11 listed)

test_train (Function) · 0.72
test_gin_mutag (Function) · 0.72
test_gin_imdb_binary (Function) · 0.72
test_gin_proteins (Function) · 0.72
test_sortpool_mutag (Function) · 0.72
test_patchy_san_mutag (Function) · 0.72
test_kmeans_cora (Function) · 0.72
test_spectral_cora (Function) · 0.72
test_prone_cora (Function) · 0.72
test_agc_cora (Function) · 0.72
test_daegc_cora (Function) · 0.72