MCPcopy
hub / github.com/THUDM/CogDL / run

Method run

cogdl/trainer/trainer.py:183–217  ·  view source on GitHub ↗
(self, model_w: ModelWrapper, dataset_w: DataWrapper)

Source from the content-addressed store, hash-verified

181 return [i for i in device_ids], len(device_ids)
182
def run(self, model_w: ModelWrapper, dataset_w: DataWrapper):
    """Train ``model_w`` on ``dataset_w`` and evaluate the reloaded checkpoint.

    Embedding models (``EmbeddingModelWrapper`` instances) are handed off
    entirely to ``EmbeddingTrainer``. Every other model follows this path:

    1. It receives the dataset's default loss function and evaluator.
    2. It is optionally resumed from ``self.checkpoint_path``.
    3. It is trained, either distributed or on ``self.devices[0]``.
    4. It is reloaded from ``self.checkpoint_path``.
    5. The reloaded wrapper is evaluated.

    Args:
        model_w: wrapper around the model to train.
        dataset_w: wrapper supplying the data plus the default loss function
            and evaluator.

    Returns:
        For embedding models, whatever ``EmbeddingTrainer(...).run`` returns.
        When ``self.return_model`` is set, the ``.model`` of the reloaded
        wrapper. Otherwise, the result of ``self.evaluate`` on the reloaded
        wrapper.
    """
    if isinstance(model_w, EmbeddingModelWrapper):
        # Network/graph embedding models have a dedicated trainer.
        embedding_trainer = EmbeddingTrainer(self.save_emb_path, self.load_emb_path)
        return embedding_trainer.run(model_w, dataset_w)

    num_params = sum(param.numel() for param in model_w.parameters())
    print("Model Parameters:", num_params)

    # Deep-learning path. Install the dataset's default loss function and
    # evaluator on the wrapper (mainly for the in-cogdl setting), then let the
    # wrapper configure its evaluation metric.
    model_w.default_loss_fn = dataset_w.get_default_loss_fn()
    model_w.default_evaluator = dataset_w.get_default_evaluator()
    model_w.set_evaluation_metric()

    if self.resume_training:
        model_w = load_model(model_w, self.checkpoint_path).to(self.devices[0])

    if not self.distributed_training:
        self.train(self.devices[0], model_w, dataset_w)
    else:
        torch.multiprocessing.set_sharing_strategy("file_system")
        self.dist_train(model_w, dataset_w)

    # Reload the weights stored at self.checkpoint_path. This is presumably the
    # best checkpoint written during training; confirm in train/dist_train.
    restored_w = load_model(model_w, self.checkpoint_path).to(self.devices[0])
    if self.return_model:
        return restored_w.model

    test_result = self.evaluate(restored_w, dataset_w)

    # Move graph data back to the CPU to release GPU memory.
    ds = dataset_w.get_dataset()
    holds_graphs = isinstance(ds.data, Graph) or hasattr(ds.data, "graphs")
    if holds_graphs:
        ds.data.to("cpu")

    return test_result
218
219 def evaluate(self, model_w: ModelWrapper, dataset_w: DataWrapper, cpu=False):
220 if cpu:

Callers 4

train_model (Function) · 0.95
test_adversarial_train (Function) · 0.95
train (Function) · 0.95
train (Function) · 0.95

Calls 10

dist_train (Method) · 0.95
train (Method) · 0.95
evaluate (Method) · 0.95
EmbeddingTrainer (Class) · 0.90
load_model (Function) · 0.90
get_default_loss_fn (Method) · 0.80
get_default_evaluator (Method) · 0.80
set_evaluation_metric (Method) · 0.80
to (Method) · 0.45
get_dataset (Method) · 0.45

Tested by 2

train_model (Function) · 0.76
test_adversarial_train (Function) · 0.76