github.com/tensorlayer/TensorLayer

Function _train_step

tensorlayer/utils.py:658–671

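Train for one step: run network on X_batch, compute the loss with cost(y_pred, y_batch), and apply one update to network.trainable_weights through train_op. Returns a (loss, metric) pair. The metric is acc(y_pred, y_batch) when acc is given and None otherwise.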

(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None)
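The default value for `train_op` deserves a note. Python evaluates default-argument expressions once, when the `def` statement runs, so every call that omits `train_op` reuses a single `tf.optimizers.Adam(learning_rate=0.0001)` object, together with any state it has accumulated (Adam keeps a step counter and per-weight moment estimates). The pure-Python sketch below demonstrates the mechanism without TensorFlow. `Counter` and `step` are hypothetical names, and `Counter` only stands in for a stateful optimizer.

```python
class Counter:
    """Hypothetical stand-in for a stateful optimizer such as Adam."""

    def __init__(self) -> None:
        self.steps = 0

    def apply(self) -> int:
        # Mutate internal state, the way an optimizer advances its step count.
        self.steps += 1
        return self.steps


def step(opt: Counter = Counter()) -> int:
    # The default Counter() is built once, when this `def` executes,
    # and the same object is reused by every call that omits `opt`.
    return opt.apply()


# Two unrelated calls that rely on the default share one state.
first = step()
second = step()
print(first, second)  # 1 2

# Passing an explicit object per model keeps the states independent.
opt_a, opt_b = Counter(), Counter()
print(step(opt_a), step(opt_a), step(opt_b))  # 1 2 1
```

Within one training run, reusing a single optimizer across steps is the intended behavior, so creating a fresh optimizer on every call would be the wrong fix: it would discard Adam's state after each step. The practical takeaway is to create one optimizer per model and pass it explicitly as `train_op`.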


@tf.function
def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
    """Train for one step"""
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = cost(y_pred, y_batch)

    grad = tape.gradient(_loss, network.trainable_weights)
    train_op.apply_gradients(zip(grad, network.trainable_weights))

    if acc is not None:
        _acc = acc(y_pred, y_batch)
        return _loss, _acc
    else:
        return _loss, None
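The control flow of `_train_step` can be mirrored without TensorFlow to make each stage explicit. This is a sketch under stated assumptions, not the library's code: a hypothetical one-feature linear model `y = w*x + b` (`LinearModel`), a mean-squared-error loss (`mse`), a hypothetical metric (`mae`), and plain gradient descent stand in for `network`, `cost`, `acc`, and the Adam `train_op`. The gradients are derived by hand instead of being recorded by `tf.GradientTape`. As in the original, the returned loss and metric come from predictions made before the update, and the metric slot is `None` when no `acc` callable is passed.

```python
from typing import Callable, List, Optional, Tuple

ScoreFn = Callable[[List[float], List[float]], float]


class LinearModel:
    """Hypothetical stand-in for `network`: y = w * x + b with two trainable weights."""

    def __init__(self, w: float = 0.0, b: float = 0.0) -> None:
        self.w = w
        self.b = b

    def __call__(self, xs: List[float]) -> List[float]:
        return [self.w * x + self.b for x in xs]


def mse(y_pred: List[float], y_true: List[float]) -> float:
    # Same (prediction, target) argument order that _train_step uses for `cost`.
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)


def mae(y_pred: List[float], y_true: List[float]) -> float:
    # Hypothetical metric playing the role of `acc`.
    return sum(abs(p - t) for p, t in zip(y_pred, y_true)) / len(y_true)


def train_step_sketch(
    model: LinearModel,
    xs: List[float],
    ys: List[float],
    cost: ScoreFn,
    lr: float = 0.1,
    acc: Optional[ScoreFn] = None,
) -> Tuple[float, Optional[float]]:
    # 1. Forward pass and loss (the body of the GradientTape block).
    y_pred = model(xs)
    loss = cost(y_pred, ys)

    # 2. Gradients of the MSE loss w.r.t. w and b, derived by hand.
    #    This stands in for tape.gradient and is only valid when cost is mse.
    n = len(xs)
    grad_w = sum(2 * (p - t) * x for p, t, x in zip(y_pred, ys, xs)) / n
    grad_b = sum(2 * (p - t) for p, t in zip(y_pred, ys)) / n

    # 3. Update (plain gradient descent instead of Adam's apply_gradients).
    model.w -= lr * grad_w
    model.b -= lr * grad_b

    # 4. Optional metric, computed on the same pre-update predictions.
    if acc is not None:
        return loss, acc(y_pred, ys)
    return loss, None


# Data from y = 2x + 1; repeated steps should drive w toward 2 and b toward 1.
xs, ys = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]

loss0, metric0 = train_step_sketch(LinearModel(), xs, ys, mse, acc=mae)
print(loss0, metric0)  # 21.0 4.0

model = LinearModel()
losses = [train_step_sketch(model, xs, ys, mse)[0] for _ in range(200)]
print(round(model.w, 3), round(model.b, 3))
```

Swapping the hand-derived gradients for `tape.gradient` and the manual update for `train_op.apply_gradients` brings the sketch back to the shape of the original. Its only purpose is to make the order of operations easy to check: predict, score, differentiate, update, then optionally measure.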

Callers (1)

train_epoch · function · 0.70

Calls (2)

gradient (tf.GradientTape.gradient) · method · 0.80
acc (the metric callable passed as acc) · function · 0.50

Tested by

no test coverage detected
