MCPcopy
hub / github.com/thunlp/OpenKE / run

Method run

openke/config/Trainer.py:56–99  ·  view source on GitHub ↗
(self)

Source retrieved from the content-addressed store (hash-verified)

54 return loss.item()
55
def run(self):
    """Create the optimizer if needed, then train ``self.model`` for ``self.train_times`` epochs.

    Attributes read from ``self``: ``use_gpu``, ``model``, ``optimizer``,
    ``opt_method``, ``alpha``, ``lr_decay``, ``weight_decay``,
    ``data_loader``, ``train_times``, ``save_steps`` and ``checkpoint_dir``.

    Optimizer selection:
        * If ``self.optimizer`` is already set (not ``None``) it is reused as-is.
        * Otherwise ``self.opt_method`` is matched case-insensitively against
          ``"adagrad"``, ``"adadelta"`` and ``"adam"``; any other value
          (including ``None`` or a non-string) falls back to ``optim.SGD``.
        * Every optimizer gets ``lr=self.alpha`` and
          ``weight_decay=self.weight_decay``; only the Adagrad branch also
          passes ``lr_decay=self.lr_decay``.

    Training loop:
        Each epoch iterates ``self.data_loader`` once and calls
        ``self.train_one_step`` on every batch. The "loss" shown in the
        progress bar is the *sum* of the per-batch losses for that epoch,
        not their mean.

    Checkpointing:
        When both ``save_steps`` and ``checkpoint_dir`` are truthy, the model
        is saved via ``self.model.save_checkpoint`` after every
        ``save_steps``-th epoch. ``checkpoint_dir`` is used as a filename
        *prefix*, not a directory: the file written is
        ``<checkpoint_dir>-<epoch>.ckpt`` with a 0-based ``epoch``.
    """
    if self.use_gpu:
        self.model.cuda()

    if self.optimizer is None:
        # Normalise the name so any capitalisation ("ADAM", "AdaGrad", ...)
        # selects the intended optimizer instead of silently falling back to
        # SGD; non-string values such as None still fall back to SGD.
        method = self.opt_method.lower() if isinstance(self.opt_method, str) else ""
        params = self.model.parameters()
        if method == "adagrad":
            self.optimizer = optim.Adagrad(
                params,
                lr=self.alpha,
                lr_decay=self.lr_decay,
                weight_decay=self.weight_decay,
            )
        elif method == "adadelta":
            self.optimizer = optim.Adadelta(
                params,
                lr=self.alpha,
                weight_decay=self.weight_decay,
            )
        elif method == "adam":
            self.optimizer = optim.Adam(
                params,
                lr=self.alpha,
                weight_decay=self.weight_decay,
            )
        else:
            self.optimizer = optim.SGD(
                params,
                lr=self.alpha,
                weight_decay=self.weight_decay,
            )
    print("Finish initializing...")

    training_range = tqdm(range(self.train_times))
    for epoch in training_range:
        res = 0.0  # summed (not averaged) loss over all batches in this epoch
        for data in self.data_loader:
            res += self.train_one_step(data)
        training_range.set_description("Epoch %d | loss: %f" % (epoch, res))

        if self.save_steps and self.checkpoint_dir and (epoch + 1) % self.save_steps == 0:
            print("Epoch %d has finished, saving..." % (epoch))
            # checkpoint_dir is a filename prefix: this writes
            # "<checkpoint_dir>-<epoch>.ckpt" next to it, not inside it.
            self.model.save_checkpoint(self.checkpoint_dir + "-" + str(epoch) + ".ckpt")
100
101 def set_model(self, model):
102 self.model = model

Calls 2

train_one_step · Method · 0.95
save_checkpoint · Method · 0.80

Tested by

no test coverage detected