MCPcopy Index your code
hub / github.com/shenweichen/GraphEmbedding / train

Method train

ge/models/sdne.py:106–150  ·  view source on GitHub ↗
(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1)

Source retrieved from the content-addressed store (hash-verified)

104 self.get_embeddings()
105
def train(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1):
    """Train the SDNE model on the graph's adjacency and Laplacian matrices.

    Rows of the dense adjacency matrix are the model input. The model is fit
    against two targets: the same adjacency rows, and the matching square
    block of the Laplacian matrix.

    If ``batch_size`` covers every node, training is delegated to a single
    ``self.model.fit`` call with ``shuffle=False``. When ``batch_size`` is
    larger than ``self.node_size``, it is first clamped to ``self.node_size``
    and a notice is printed.

    Otherwise, nodes are split into contiguous, unshuffled mini-batches, and
    each batch is trained with ``self.model.train_on_batch``. The per-epoch
    losses are the unweighted mean of the per-batch losses. They are recorded
    in a ``History`` object under ``loss``, ``2nd_loss`` and ``1st_loss``.

    Args:
        batch_size: Number of nodes per mini-batch. Must be >= 1.
        epochs: Exclusive upper bound on the epoch index. Epochs
            ``initial_epoch`` .. ``epochs - 1`` are run, and the same value is
            forwarded to ``fit`` in the full-batch path.
        initial_epoch: Epoch index to start from.
        verbose: In the mini-batch path, a per-epoch summary is printed when
            this is > 0. In the full-batch path it is forwarded to ``fit``.

    Returns:
        In the full-batch path, the return value of ``self.model.fit``.
        Otherwise, the populated ``History`` object.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Without this check,
            0 would fail with ZeroDivisionError, and a negative value would
            silently run zero batches per epoch.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {0}'.format(batch_size))

    adjacency = self.A.toarray().astype(np.float32)
    laplacian = self.L.toarray().astype(np.float32)

    if batch_size >= self.node_size:
        if batch_size > self.node_size:
            print('batch_size({0}) > node_size({1}),set batch_size = {1}'.format(
                batch_size, self.node_size))
            batch_size = self.node_size
        return self.model.fit(
            adjacency,
            [adjacency, laplacian],
            batch_size=batch_size,
            epochs=epochs,
            initial_epoch=initial_epoch,
            verbose=verbose,
            shuffle=False,
        )

    # Ceiling division: the final batch may hold fewer than batch_size nodes.
    steps_per_epoch = (self.node_size - 1) // batch_size + 1
    hist = History()
    hist.set_model(self.model)
    hist.on_train_begin()
    logs = {}
    for epoch in range(initial_epoch, epochs):
        start_time = time.time()
        # Assumes train_on_batch returns exactly three losses, in the order
        # they are logged below (total, 2nd-order, 1st-order).
        losses = np.zeros(3)
        for step in range(steps_per_epoch):
            start = step * batch_size
            stop = min(start + batch_size, self.node_size)
            # Batches are contiguous node ranges. Basic slicing therefore
            # selects the same rows/block as fancy indexing with
            # np.arange(start, stop), but yields views. In particular, it
            # avoids copying a (batch x node_size) slab of the Laplacian just
            # to extract the (batch x batch) block from it.
            A_train = adjacency[start:stop]
            L_mat_train = laplacian[start:stop, start:stop]
            batch_losses = np.asarray(
                self.model.train_on_batch(A_train, [A_train, L_mat_train]))
            losses += batch_losses
        losses = losses / steps_per_epoch

        logs['loss'] = losses[0]
        logs['2nd_loss'] = losses[1]
        logs['1st_loss'] = losses[2]
        epoch_time = int(time.time() - start_time)
        hist.on_epoch_end(epoch, logs)
        if verbose > 0:
            print('Epoch {0}/{1}'.format(epoch + 1, epochs))
            print('{0}s - loss: {1: .4f} - 2nd_loss: {2: .4f} - 1st_loss: {3: .4f}'.format(
                epoch_time, losses[0], losses[1], losses[2]))
    return hist
151
152 def evaluate(self, ):
153 adjacency = self.A.toarray().astype(np.float32)

Callers 2

test_SDNEFunction · 0.95
mainFunction · 0.95

Calls

no outgoing calls

Tested by 1

test_SDNEFunction · 0.76