MCPcopy
hub / github.com/THUDM/CogDL / train

Function train

examples/simple_trafficPre/example.py:41–167  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

39
40
41def train(args): # noqa: C901
42 if isinstance(args.dataset, list):
43 args.dataset = args.dataset[0]
44 if isinstance(args.model, list):
45 args.model = args.model[0]
46 if isinstance(args.seed, list):
47 args.seed = args.seed[0]
48 if isinstance(args.split, list):
49 args.split = args.split[0]
50 # dataset='cora', model='gcn', seed=1, split=0
51 set_random_seed(args.seed)
52
53
54 model_name = args.model if isinstance(args.model, str) else args.model.model_name
55 dw_name = args.dw if isinstance(args.dw, str) else args.dw.__name__
56 mw_name = args.mw if isinstance(args.mw, str) else args.mw.__name__
57
58 print(
59 f"""
60|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|
61 *** Running (`{args.dataset}`, `{model_name}`, `{dw_name}`, `{mw_name}`)
62|-------------------------------------{'-' * (
63 len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|"""
64 )
65
66
67 dataset = build_dataset(args)
68 mw_class = fetch_model_wrapper(args.mw)
69 dw_class = fetch_data_wrapper(args.dw)
70
71 if mw_class is None:
72 raise NotImplementedError("`model wrapper(--mw)` must be specified.")
73
74 if dw_class is None:
75 raise NotImplementedError("`data wrapper(--dw)` must be specified.")
76
77 data_wrapper_args = dict()
78 model_wrapper_args = dict()
79
80 data_wrapper_args['batch_size'] = args.batch_size
81 data_wrapper_args['n_his'] = args.n_his
82 data_wrapper_args['n_pred'] = args.n_pred
83 data_wrapper_args['train_prop'] = args.train_prop
84 data_wrapper_args['val_prop'] = args.val_prop
85 data_wrapper_args['test_prop'] = args.test_prop
86 data_wrapper_args['pred_length'] = args.pred_length
87
88 dataset_wrapper = dw_class(dataset, **data_wrapper_args)
89
90 args.num_features = dataset.num_features
91 if hasattr(dataset, "num_nodes"):
92 args.num_nodes = dataset.num_nodes
93 if hasattr(dataset, "num_edges"):
94 args.num_edges = dataset.num_edges
95 if hasattr(dataset, "num_edge"):
96 args.num_edge = dataset.num_edge
97 if hasattr(dataset, "max_graph_size"):
98 args.max_graph_size = dataset.max_graph_size

Callers 1

raw_experiment (function) · 0.70

Calls 8

run (method) · 0.95
set_random_seed (function) · 0.90
build_dataset (function) · 0.90
fetch_model_wrapper (function) · 0.90
fetch_data_wrapper (function) · 0.90
build_model (function) · 0.90
Trainer (class) · 0.90
get_pre_timestamp (method) · 0.45

Tested by

no test coverage detected