MCPcopy
hub / github.com/THUDM/CogDL / train

Method train

examples/bgrl/train.py:63–106  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

61 self._scheduler = optim.lr_scheduler.LambdaLR(self._optimizer, lr_lambda=lr_scheduler)
62
def train(self):
    """Run the BGRL training loop with periodic evaluation.

    First calls ``infer_embeddings()`` / ``evaluate()`` once and prints the
    initial test score. Then, for ``self._args.epochs`` epochs:

    * builds two feature-masked views of ``self._dataset``,
    * feeds them to ``self._model``, which returns the loss as its third output,
    * takes one optimizer step and one LR-scheduler step,
    * calls ``self._model.update_moving_average()``.

    Every ``self._args.cache_step`` epochs it calls ``infer_embeddings()`` and
    ``evaluate()`` again. It logs the current learning rate and test accuracy
    to ``self.writer``.

    Reads from ``self._args``: ``epochs``, ``cache_step`` and ``aug_params``.
    ``aug_params`` must provide at least four entries, each convertible to
    ``float``.
    """
    # get initial test results
    print(self._args)
    print("start training!")

    print("Initial Evaluation...")
    self.infer_embeddings()
    test_acc, _ = self.evaluate()
    print("test: {:.4f}".format(test_acc))

    # The augmentation parameters never change between epochs, so build the
    # augmenter once instead of re-creating it on every iteration.
    aug_params = [float(self._args.aug_params[i]) for i in range(4)]
    augmentation = utils.Augmentation(*aug_params)

    # start training
    for epoch in range(self._args.epochs):
        # Re-enter training mode at the start of every epoch. The periodic
        # evaluation below runs between epochs, and previously train() was
        # only called once before the loop.
        # NOTE(review): infer_embeddings/evaluate are not visible here; if
        # they switch the model to eval mode, later epochs would otherwise
        # silently train in eval mode. Calling train() again is harmless.
        self._model.train()
        # Cheap no-op when the data is already on the device; kept inside the
        # loop in case evaluation moves the dataset elsewhere.
        self._dataset.to(self._device)

        view1, view2 = augmentation._feature_masking(self._dataset, self._device)

        # Only the loss (third output) is needed for the optimization step.
        _, _, loss = self._model(
            x1=view1.x, x2=view2.x, graph_v1=view1, graph_v2=view2,
            edge_weight_v1=view1.edge_attr, edge_weight_v2=view2.edge_attr)

        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        self._scheduler.step()
        self._model.update_moving_average()

        lr = self._optimizer.param_groups[0]["lr"]
        # loss.item() yields a plain Python float (preferred over .data).
        sys.stdout.write("\rEpoch {}/{}, loss {:.4f}, lr {}".format(
            epoch + 1, self._args.epochs, loss.item(), lr))
        sys.stdout.flush()

        if (epoch + 1) % self._args.cache_step == 0:
            # Terminate the carriage-return progress line before printing.
            print("")
            print("\nEvaluating {}th epoch..".format(epoch + 1))

            self.infer_embeddings()
            test_acc, _ = self.evaluate()

            self.writer.add_scalar("stats/learning_rate", lr, epoch + 1)
            self.writer.add_scalar("accs/test_acc", test_acc, epoch + 1)
            print("test: {:.4f} \n".format(test_acc))

    print()
    print("Training Done!")
107
108 def infer_embeddings(self):
109

Callers 15

train_eval — Function · 0.95
infer_embeddings — Method · 0.45
train — Function · 0.45
preprocessing — Method · 0.45
pretrain — Function · 0.45
pretrain — Function · 0.45
pretrain — Function · 0.45
train — Function · 0.45
pretrain — Function · 0.45

Calls 8

infer_embeddings — Method · 0.95
evaluate — Method · 0.95
_feature_masking — Method · 0.95
update_moving_average — Method · 0.80
to — Method · 0.45
zero_grad — Method · 0.45
backward — Method · 0.45
step — Method · 0.45

Tested by

no test coverage detected