| 21 | |
| 22 | |
def save(steps=100, learning_rate=0.5, save_path='./params'):
    """Build and train a small regression net, checkpoint its variables, and plot the fit.

    Relies on module-level names not defined here: ``x`` and ``y`` (the
    arrays fed as input and target), ``tf`` (TensorFlow, used through its
    1.x-style graph API: placeholders, ``Session``, ``train.Saver``) and
    ``plt`` (matplotlib.pyplot).

    Args:
        steps: Number of gradient-descent training steps (default 100).
        learning_rate: Step size for ``GradientDescentOptimizer`` (default 0.5).
        save_path: Checkpoint prefix passed to ``Saver.save`` (default
            ``'./params'``). No meta graph is written, so whoever restores
            the checkpoint must rebuild the same graph first.
    """
    print('This is save')
    # build neural network
    tf_x = tf.placeholder(tf.float32, x.shape)  # input x
    tf_y = tf.placeholder(tf.float32, y.shape)  # input y
    hidden = tf.layers.dense(tf_x, 10, tf.nn.relu)  # hidden layer
    output = tf.layers.dense(hidden, 1)  # output layer
    loss = tf.losses.mean_squared_error(tf_y, output)  # compute cost
    train_op = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)

    feed = {tf_x: x, tf_y: y}
    # `with` releases the session even if training, saving, or evaluation
    # raises; previously the session was never closed.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # initialize var in graph
        saver = tf.train.Saver()  # define a saver for saving and restoring

        for _ in range(steps):  # train
            sess.run(train_op, feed)

        saver.save(sess, save_path, write_meta_graph=False)  # meta_graph is not recommended

        # Evaluate predictions and final loss for the plot. The loss value gets
        # its own name instead of re-binding the hidden-layer tensor's name.
        pred, loss_value = sess.run([output, loss], feed)

    # plotting (pred / loss_value are plain arrays, usable after the session closes)
    plt.figure(1, figsize=(10, 5))
    plt.subplot(121)
    plt.scatter(x, y)
    plt.plot(x, pred, 'r-', lw=5)
    plt.text(-1, 1.2, 'Save Loss=%.4f' % loss_value, fontdict={'size': 15, 'color': 'red'})
| 50 | |
| 51 | |
| 52 | def reload(): |