(self, b_o, b_index, b_m)
| 266 | |
@tf.function
def _train_func(self, b_o, b_index, b_m):
    """Run one gradient-descent update of ``self.qnet`` on a batch.

    The network is evaluated on ``b_o``, the entries addressed by
    ``b_index`` are picked out with ``tf.gather_nd``, and the loss is the
    mean of ``-sum(selected * b_m, axis=1)``.  Gradients of that loss with
    respect to ``self.qnet.trainable_weights`` are then applied through
    ``self.optimizer``.

    NOTE(review): presumably ``self.qnet`` emits per-action
    log-probabilities over atoms and ``b_m`` is the projected target
    distribution, which would make this a cross-entropy loss -- confirm
    against the network definition and the caller.

    Args:
        b_o: Batch fed to ``self.qnet``.
        b_index: Indices passed to ``tf.gather_nd`` to select entries of
            the network output.
        b_m: Weights multiplied element-wise with the selected outputs.

    Returns:
        None. The update is performed for its side effect on the network
        weights.
    """
    with tf.GradientTape() as tape:
        # Forward pass is recorded on the tape so it can be differentiated.
        net_out = self.qnet(b_o)
        selected = tf.gather_nd(net_out, b_index)
        weighted_sum = tf.reduce_sum(selected * b_m, 1)
        loss = tf.reduce_mean(tf.negative(weighted_sum))

    # Read the weight list after the forward pass, as the original does.
    weights = self.qnet.trainable_weights
    grads = tape.gradient(loss, weights)
    self.optimizer.apply_gradients(zip(grads, weights))
| 275 | |
| 276 | |
| 277 | # ############################# Trainer ################################### |