Repository: github.com/tensorlayer/TensorLayer

Function _train_step

Source: examples/quantized_net/tutorial_quanconv_mnist.py, lines 51–61
(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None)
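The default `train_op=tf.optimizers.Adam(learning_rate=0.0001)` is evaluated once, when the `def` statement runs, not on every call. Every call that omits `train_op` therefore reuses the same optimizer instance and its internal state (for Adam, the running moment estimates). The sketch below shows this Python behaviour without TensorFlow, using a hypothetical `FakeOptimizer` stand-in that only counts updates.

```python
class FakeOptimizer:
    """Hypothetical stand-in for tf.optimizers.Adam; it only counts updates."""

    instances = 0  # how many FakeOptimizer objects have been constructed

    def __init__(self):
        FakeOptimizer.instances += 1
        self.steps = 0  # plays the role of the optimizer's internal state

    def apply_gradients(self, grads_and_vars):
        self.steps += 1


def step(train_op=FakeOptimizer()):
    # The default above was built once, when this `def` statement executed.
    train_op.apply_gradients([])
    return train_op


first = step()
second = step()
print(first is second, first.steps, FakeOptimizer.instances)  # True 2 1
```

Passing an optimizer explicitly (for example `step(FakeOptimizer())`) bypasses the shared default.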


import tensorflow as tf  # `tf` refers to TensorFlow 2.x


def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = cost(y_pred, y_batch)
    # Gradients of the loss w.r.t. every trainable weight, then one optimizer update.
    grad = tape.gradient(_loss, network.trainable_weights)
    train_op.apply_gradients(zip(grad, network.trainable_weights))
    # Optionally report accuracy on the same batch (from the pre-update predictions).
    if acc is not None:
        _acc = acc(y_pred, y_batch)
        return _loss, _acc
    else:
        return _loss, None
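A minimal usage sketch, assuming TensorFlow 2.x. It trains a hypothetical one-layer `tf.keras` classifier on random data in place of the tutorial's TensorLayer quantized network (both expose `trainable_weights`). The helpers `softmax_cost` and `batch_accuracy`, the data shapes, and the learning rate of 0.01 are illustrative assumptions, not the tutorial's code. `train_step` is a copy of `_train_step` with `train_op` made a required argument; note that `cost` and `acc` receive `(predictions, targets)` in that order.

```python
import numpy as np
import tensorflow as tf


def train_step(network, X_batch, y_batch, cost, train_op, acc=None):
    """Copy of _train_step with the optimizer passed explicitly (no shared default)."""
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        loss = cost(y_pred, y_batch)
    grads = tape.gradient(loss, network.trainable_weights)
    train_op.apply_gradients(zip(grads, network.trainable_weights))
    return loss, (acc(y_pred, y_batch) if acc is not None else None)


def softmax_cost(logits, labels):
    # Hypothetical loss: mean softmax cross-entropy, called as cost(predictions, targets).
    return tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    )


def batch_accuracy(logits, labels):
    # Hypothetical metric: fraction of rows whose arg-max logit matches the label.
    correct = tf.equal(tf.argmax(logits, axis=1), labels)
    return tf.reduce_mean(tf.cast(correct, tf.float32))


# Toy data: 32 samples, 4 features, 3 classes (illustrative only).
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4)).astype("float32")
y = rng.integers(0, 3, size=32).astype("int64")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

losses = []
for _ in range(50):
    loss, acc_val = train_step(model, X, y, softmax_cost, optimizer, acc=batch_accuracy)
    losses.append(float(loss))

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, accuracy {float(acc_val):.2f}")
```

Passing the optimizer explicitly ties its state to one model and keeps the learning rate visible at the call site; with the original signature, omitting `train_op` falls back to the single Adam instance (learning rate 0.0001) created when `_train_step` was defined.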

Callers: 1

Calls: 2
gradient (method) · 0.80
acc (function) · 0.50

Tested by: no test coverage detected
