MCPcopy
hub / github.com/THUDM/CogDL / main

Function main

examples/graphmae/main_graph.py:109–192  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

107
108
109def main(args):
110 device = args.device if args.device >= 0 else "cpu"
111 seeds = args.seeds
112 dataset_name = args.dataset
113 max_epoch = args.max_epoch
114 max_epoch_f = args.max_epoch_f
115 num_hidden = args.num_hidden
116 num_layers = args.num_layers
117 encoder_type = args.encoder
118 decoder_type = args.decoder
119 replace_rate = args.replace_rate
120
121 optim_type = args.optimizer
122 loss_fn = args.loss_fn
123
124 lr = args.lr
125 weight_decay = args.weight_decay
126 lr_f = args.lr_f
127 weight_decay_f = args.weight_decay_f
128 linear_prob = args.linear_prob
129 load_model = args.load_model
130 save_model = args.save_model
131 logs = args.logging
132 use_scheduler = args.scheduler
133 pooling = args.pooling
134 deg4feat = args.deg4feat
135 batch_size = args.batch_size
136
137 graphs, (num_features, num_classes) = load_graph_classification_dataset(dataset_name, deg4feat=deg4feat)
138 args.num_features = num_features
139
140 train_loader = DataLoader(graphs, collate_fn=collate_fn, batch_size=batch_size, pin_memory=True)
141 eval_loader = DataLoader(graphs, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
142
143 if pooling == "mean":
144 pooler = batch_mean_pooling
145 elif pooling == "max":
146 pooler = batch_max_pooling
147 elif pooling == "sum":
148 pooler = batch_sum_pooling
149 else:
150 raise NotImplementedError
151
152 acc_list = []
153 for i, seed in enumerate(seeds):
154 print(f"####### Run {i} for seed {seed}")
155 set_random_seed(seed)
156
157 if logs:
158 logger = TBLogger(name=f"{dataset_name}_loss_{loss_fn}_rpr_{replace_rate}_nh_{num_hidden}_nl_{num_layers}_lr_{lr}_mp_{max_epoch}_mpf_{max_epoch_f}_wd_{weight_decay}_wdf_{weight_decay_f}_{encoder_type}_{decoder_type}")
159 else:
160 logger = None
161
162 model = build_model(args)
163 model.to(device)
164 optimizer = create_optimizer(optim_type, model, lr, weight_decay)
165
166 if use_scheduler:

Callers 1

main_graph.py (file) · 0.70

Calls 10

DataLoader (class) · 0.90
set_random_seed (function) · 0.90
TBLogger (class) · 0.90
build_model (function) · 0.90
create_optimizer (function) · 0.90
pretrain (function) · 0.70
to (method) · 0.45
eval (method) · 0.45

Tested by

no test coverage detected