Repository: github.com/tensorlayer/TensorLayer

Function _train_step

examples/quantized_net/tutorial_quanconv_cifar10.py, lines 133–143
(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None)
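The default `train_op=tf.optimizers.Adam(learning_rate=0.0001)` is an ordinary Python default argument. Python evaluates it once, when the `def` statement runs, so every call that omits `train_op` reuses the same Adam instance and its internal state. The following sketch shows that behavior without TensorFlow. `CountingOptimizer`, `step_shared`, and `step_fresh` are hypothetical names used only for illustration and are not part of TensorLayer. `step_shared` follows the listing's pattern, and `step_fresh` shows the common `train_op=None` alternative.

```python
class CountingOptimizer:
    """Hypothetical optimizer stand-in that only counts update steps."""

    def __init__(self):
        self.steps = 0

    def apply_gradients(self, grads_and_vars):
        # Consume the (gradient, variable) pairs and record one update.
        list(grads_and_vars)
        self.steps += 1


def step_shared(train_op=CountingOptimizer()):
    # Same pattern as _train_step: the default instance is created once, at def time.
    train_op.apply_gradients(zip([0.0], [0.0]))
    return train_op


def step_fresh(train_op=None):
    # Alternative pattern: build a new optimizer only if the caller passes none.
    if train_op is None:
        train_op = CountingOptimizer()
    train_op.apply_gradients(zip([0.0], [0.0]))
    return train_op


a, b = step_shared(), step_shared()
print(a is b, b.steps)  # True 2

c, d = step_fresh(), step_fresh()
print(c is d, d.steps)  # False 1
```

A training loop needs one optimizer that persists across steps, because Adam's moment estimates must carry over from one update to the next. Creating a fresh optimizer on every call would reset them. The shared default does provide that persistence, but the same instance is also shared by every model trained through the default in the same process. Creating one optimizer per model and passing it explicitly as `train_op` makes that ownership clear.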


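import tensorflow as tf  # module-level import in the tutorial file


# Note: the default Adam optimizer is created once, when this def runs,
# and is shared by every call that omits train_op.
def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = cost(y_pred, y_batch)
    # Gradients of the loss w.r.t. the trainable weights, then one optimizer update.
    grad = tape.gradient(_loss, network.trainable_weights)
    train_op.apply_gradients(zip(grad, network.trainable_weights))
    # Accuracy is optional; the second return value is None when acc is omitted.
    if acc is not None:
        _acc = acc(y_pred, y_batch)
        return _loss, _acc
    else:
        return _loss, None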
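The following usage sketch is not taken from the tutorial. It assumes TensorFlow 2.x and replaces the tutorial's TensorLayer network with a small `tf.keras` model. The listing only requires an object that can be called on a batch and exposes `trainable_weights`. `sparse_xent` and `batch_accuracy` are hypothetical helpers that follow the `(prediction, target)` argument order `_train_step` uses for `cost` and `acc`. The data is random, so the printed numbers have no meaning beyond showing that the loop runs.

```python
import tensorflow as tf


# Copy of the listing's _train_step; the default optimizer is spelled with
# tf.keras.optimizers here instead of the tf.optimizers alias.
def _train_step(network, X_batch, y_batch, cost,
                train_op=tf.keras.optimizers.Adam(learning_rate=0.0001), acc=None):
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = cost(y_pred, y_batch)
    grad = tape.gradient(_loss, network.trainable_weights)
    train_op.apply_gradients(zip(grad, network.trainable_weights))
    if acc is not None:
        return _loss, acc(y_pred, y_batch)
    return _loss, None


# Hypothetical cost: mean sparse softmax cross-entropy, (logits, labels) order.
def sparse_xent(logits, labels):
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))


# Hypothetical metric: fraction of rows whose argmax matches the label.
def batch_accuracy(logits, labels):
    hits = tf.equal(tf.argmax(logits, axis=1, output_type=tf.int32), labels)
    return tf.reduce_mean(tf.cast(hits, tf.float32))


tf.random.set_seed(0)
# Small stand-in network: callable on a batch and exposes .trainable_weights.
network = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
X_batch = tf.random.normal((8, 4))
y_batch = tf.constant([0, 1, 2, 0, 1, 2, 0, 1], dtype=tf.int32)
# One optimizer, created by the caller and passed explicitly on every step.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)

losses = []
for _ in range(50):
    loss, acc_value = _train_step(network, X_batch, y_batch, sparse_xent,
                                  train_op=optimizer, acc=batch_accuracy)
    losses.append(float(loss))

# Without acc, the second element of the returned pair is None.
loss_only, no_acc = _train_step(network, X_batch, y_batch, sparse_xent, train_op=optimizer)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, "
      f"acc {float(acc_value):.2f}, no_acc={no_acc}")
```

The returned loss is the value computed before that step's update, because the gradient is taken on the same forward pass.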

Callers: 1

Calls: 2

gradient (method) · 0.80
acc (function) · 0.50

Tested by

no test coverage detected
