(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None)
| 131 | |
| 132 | |
# Lazily-created fallback optimizer shared by every ``_train_step`` call that
# does not pass ``train_op``. Kept at module level (instead of as a default
# argument value) so nothing is built at import time, while still giving all
# such calls one persistent optimizer whose state carries across steps.
_default_train_op = None


def _get_default_train_op():
    """Return the shared fallback optimizer, creating it on first use.

    The original signature evaluated ``tf.optimizers.Adam(learning_rate=0.0001)``
    once, when the function was defined, and reused that single instance for
    every call that omitted ``train_op``. This helper preserves the
    "single shared instance" behaviour but defers construction until needed.
    """
    global _default_train_op
    if _default_train_op is None:
        _default_train_op = tf.optimizers.Adam(learning_rate=0.0001)
    return _default_train_op


def _train_step(network, X_batch, y_batch, cost, train_op=None, acc=None):
    """Run one optimisation step of ``network`` on a single batch.

    Parameters
    ----------
    network : callable
        Model called as ``network(X_batch)``; must expose
        ``trainable_weights``.
    X_batch, y_batch :
        Inputs and targets for this step.
    cost : callable
        Loss function called as ``cost(y_pred, y_batch)`` (prediction first).
    train_op : optimizer, optional
        Object providing ``apply_gradients``. When omitted or ``None``, a
        module-wide ``tf.optimizers.Adam(learning_rate=0.0001)`` is used. It
        is created on first use and shared by all calls that omit
        ``train_op``.
    acc : callable, optional
        Metric called as ``acc(y_pred, y_batch)``; skipped when ``None``.

    Returns
    -------
    tuple
        ``(loss, accuracy)``. ``accuracy`` is ``None`` when ``acc`` is
        ``None``.
    """
    if train_op is None:
        train_op = _get_default_train_op()

    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = cost(y_pred, y_batch)

    # Compute gradients outside the tape context so the gradient computation
    # itself is not recorded on the (non-persistent) tape.
    grad = tape.gradient(_loss, network.trainable_weights)
    train_op.apply_gradients(zip(grad, network.trainable_weights))

    # The metric is evaluated on the same forward-pass predictions used for
    # the loss, i.e. before this step's weight update is reflected.
    _acc = acc(y_pred, y_batch) if acc is not None else None
    return _loss, _acc
| 144 | |
| 145 | |
| 146 | def accuracy(_logits, y_batch): |
no test coverage detected
searching dependent graphs…