Train for one step
(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None)
| 656 | |
@tf.function
def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
    """Run one optimisation step on a single mini-batch.

    Parameters
    ----------
    network : callable
        Model called as ``network(X_batch)``. It must expose
        ``trainable_weights``.
    X_batch : tensor-like
        Input mini-batch passed straight to ``network``.
    y_batch : tensor-like
        Targets passed to ``cost`` (and ``acc``) after the predictions.
    cost : callable
        Loss function, invoked as ``cost(y_pred, y_batch)``.
    train_op : optimizer
        Object whose ``apply_gradients`` receives ``(gradient, weight)``
        pairs. NOTE(review): the default expression is evaluated once, when
        the function is defined. Every call that omits ``train_op`` therefore
        shares one Adam instance (learning rate 1e-4) and its state.
    acc : callable or None
        Optional metric, invoked as ``acc(y_pred, y_batch)``.

    Returns
    -------
    tuple
        ``(loss, accuracy)``. ``accuracy`` is ``None`` when ``acc`` is None.
    """
    with tf.GradientTape() as tape:
        predictions = network(X_batch)
        loss_value = cost(predictions, y_batch)

    # Read the weight list only after the forward pass, so weights that are
    # built lazily on the first call already exist.
    weights = network.trainable_weights
    gradients = tape.gradient(loss_value, weights)
    train_op.apply_gradients(zip(gradients, weights))

    # The metric is computed (outside the tape) only when one was supplied.
    accuracy = None if acc is None else acc(predictions, y_batch)
    return loss_value, accuracy
| 672 | |
| 673 | |
| 674 | # @tf.function # FIXME: enabling tf.function causes some numpy-related bugs; needs fixing |
no test coverage detected
searching dependent graphs…