MCPcopy
hub / github.com/MrNothing/AI-Blocks / Train

Function Train

Sources/scripts/auto_encoder.py:36–81  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

34 self.Decode()
35
def Train(self):
    """Build the auto-encoder training graph and train it.

    A placeholder ``self.X`` of shape ``[None, self._input.input_size]`` is
    encoded with ``self.Encoder.Run`` and decoded with ``self.Decoder.Run``.
    The mean squared reconstruction error is minimized with Adam at
    ``self.learning_rate`` for ``self.training_iterations`` steps. Batches
    come from ``self._input.getNextBatch()``; element ``[0]`` of each batch
    is used as the input.

    Every ``self.display_step`` iterations, progress and the current loss are
    reported through ``SetState`` / ``SendChartData``. Every
    ``10 * self.display_step`` iterations, the first sample of the batch and
    its reconstruction are sent through ``SendImageData``, sized by
    ``self._input.image_size``.

    If ``self.save_path`` is non-empty, weights are restored from
    ``<save_path>/model`` when ``<save_path>/model.meta`` exists. They are
    saved back to that prefix after the training loop finishes.

    Side effect: assigns the input placeholder to ``self.X``.
    """
    self.X = tf.placeholder(tf.float32, shape=[None, self._input.input_size], name="x_input")

    z = self.Encoder.Run(self.X)
    fakeX = self.Decoder.Run(z)

    # Mean squared error between the reconstruction and the original input.
    AE_loss = tf.reduce_mean(tf.pow(fakeX - self.X, 2))
    AE_solver = tf.train.AdamOptimizer(self.learning_rate).minimize(AE_loss)

    # Created after minimize() so the optimizer's own variables are initialized too.
    init = tf.global_variables_initializer()

    # 'Saver' op to save and restore all the variables
    saver = tf.train.Saver()

    # A falsy save_path (None or "") disables checkpointing entirely.
    checkpoint_prefix = self.save_path + "/model" if self.save_path else None

    # Launch the graph
    with tf.Session() as sess:
        sess.run(init)

        if checkpoint_prefix and os.path.exists(checkpoint_prefix + ".meta"):
            # Restore model weights from previously saved model
            saver.restore(sess, checkpoint_prefix)
            print("Model restored from file: %s" % self.save_path)

        # Parenthesized on purpose: `it % step*10` would parse as
        # `(it % step) * 10`, which fires on every display step.
        preview_step = self.display_step * 10

        for it in range(self.training_iterations):
            batch = self._input.getNextBatch()
            X_batch = batch[0]

            _, loss = sess.run([AE_solver, AE_loss], feed_dict={self.X: X_batch})

            if it % self.display_step == 0:
                # float() keeps the progress fraction from truncating to 0 on Python 2.
                SetState(self.id, float(it) / self.training_iterations)
                SendChartData(self.id, "Loss", loss, "#ff0000")

            if it % preview_step == 0:
                # Reconstruct only the first sample of the current batch for display.
                test_X = [X_batch[0]]
                rebuilt_image = sess.run(fakeX, feed_dict={self.X: test_X})[0]
                image_size = self._input.image_size
                SendImageData(self.id, test_X[0], image_size[0], image_size[1], "original")
                SendImageData(self.id, rebuilt_image, image_size[0], image_size[1], "fake")

        if checkpoint_prefix:
            # Save model weights to disk
            s_path = saver.save(sess, checkpoint_prefix)
            print("Model saved in file: %s" % s_path)

Callers

nothing calls this directly

Calls 5

SetState (function) · 0.85
SendChartData (function) · 0.85
SendImageData (function) · 0.85
save (method) · 0.80
Run (method) · 0.45

Tested by

no test coverage detected