MCPcopy
hub / github.com/THUDM/CogDL / train

Method train

cogdl/trainer/trainer.py:305–434  ·  view source on GitHub ↗
(self, rank, model_w, dataset_w)

Source from the content-addressed store, hash-verified

303 return model_w.to(rank), None
304
305 def train(self, rank, model_w, dataset_w): # noqa: C901
306 model_w, _ = self.initialize(model_w, rank=rank, master_addr=self.master_addr, master_port=self.master_port)
307 self.data_controller.prepare_data_wrapper(dataset_w, rank)
308 self.eval_data_back_to_cpu = dataset_w.data_back_to_cpu
309
310 optimizers, lr_schedulers = self.build_optimizer(model_w)
311 if optimizers[0] is None:
312 return
313
314 est = model_w.set_early_stopping()
315 if isinstance(est, str):
316 est_monitor = est
317 best_index, compare_fn = evaluation_comp(est_monitor)
318 else:
319 assert len(est) == 2
320 est_monitor, est_compare = est
321 best_index, compare_fn = evaluation_comp(est_monitor, est_compare)
322 self.monitor = est_monitor
323 self.evaluation_metric = model_w.evaluation_metric
324
325 best_model_w = None
326
327 scaler = GradScaler() if self.fp16 else None
328
329 patience = 0
330 best_epoch = 0
331 for stage in range(self.nstage):
332 with torch.no_grad():
333 pre_stage_out = model_w.pre_stage(stage, dataset_w)
334 dataset_w.pre_stage(stage, pre_stage_out)
335 self.data_controller.training_proc_per_stage(dataset_w, rank)
336
337 if self.progress_bar == "epoch":
338 epoch_iter = tqdm(range(1, self.epochs + 1))
339 epoch_printer = Printer(epoch_iter.set_description, rank=rank, world_size=self.world_size)
340 else:
341 epoch_iter = range(1, self.epochs + 1)
342 epoch_printer = Printer(print, rank=rank, world_size=self.world_size)
343
344 self.logger.start()
345 print_str_dict = dict()
346 if self.attack is not None:
347 graph = dataset_w.dataset.data
348 graph_backup = copy.deepcopy(graph)
349 graph0 = copy.deepcopy(graph)
350 num_train = torch.sum(graph.train_mask).item()
351 for epoch in epoch_iter:
352 for hook in self.pre_epoch_hooks:
353 hook(self)
354
355 # inductive setting ..
356 dataset_w.train()
357 train_loader = dataset_w.on_train_wrapper()
358 train_dataset = train_loader.get_dataset_from_loader()
359 if hasattr(train_dataset, "shuffle"):
360 train_dataset.shuffle()
361 training_loss = self.train_step(model_w, train_loader, optimizers, lr_schedulers, rank, scaler)
362

Callers 2

run (Method) · 0.95
train_step (Method) · 0.45

Calls 15

initialize (Method) · 0.95
build_optimizer (Method) · 0.95
train_step (Method) · 0.95
validate (Method) · 0.95
evaluation_comp (Function) · 0.90
Printer (Class) · 0.90
adj_preprocess (Function) · 0.90
updateGraph (Function) · 0.90
adj_to_tensor (Function) · 0.90
save_model (Function) · 0.90
on_train_wrapper (Method) · 0.80

Tested by

no test coverage detected