* Perform training on a randomly sampled batch from the replay buffer. * * @param {number} batchSize Batch size. * @param {number} gamma Reward discount rate. Must be >= 0 and <= 1. * @param {tf.train.Optimizer} optimizer The optimizer object used to update * the weights of the onli
(batchSize, gamma, optimizer)
| 129 | * the weights of the online network. |
| 130 | */ |
| 131 | trainOnReplayBatch(batchSize, gamma, optimizer) { |
| 132 | // Get a batch of examples from the replay buffer. |
| 133 | const batch = this.replayMemory.sample(batchSize); |
| 134 | const lossFunction = () => tf.tidy(() => { |
| 135 | const stateTensor = getStateTensor( |
| 136 | batch.map(example => example[0]), this.game.height, this.game.width); |
| 137 | const actionTensor = tf.tensor1d( |
| 138 | batch.map(example => example[1]), 'int32'); |
| 139 | const qs = this.onlineNetwork.apply(stateTensor, {training: true}) |
| 140 | .mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1); |
| 141 | |
| 142 | const rewardTensor = tf.tensor1d(batch.map(example => example[2])); |
| 143 | const nextStateTensor = getStateTensor( |
| 144 | batch.map(example => example[4]), this.game.height, this.game.width); |
| 145 | const nextMaxQTensor = |
| 146 | this.targetNetwork.predict(nextStateTensor).max(-1); |
| 147 | const doneMask = tf.scalar(1).sub( |
| 148 | tf.tensor1d(batch.map(example => example[3])).asType('float32')); |
| 149 | const targetQs = |
| 150 | rewardTensor.add(nextMaxQTensor.mul(doneMask).mul(gamma)); |
| 151 | return tf.losses.meanSquaredError(targetQs, qs); |
| 152 | }); |
| 153 | |
| 154 | // Calculate the gradients of the loss function with repsect to the weights |
| 155 | // of the online DQN. |
| 156 | const grads = tf.variableGrads(lossFunction); |
| 157 | // Use the gradients to update the online DQN's weights. |
| 158 | optimizer.applyGradients(grads.grads); |
| 159 | tf.dispose(grads); |
| 160 | // TODO(cais): Return the loss value here? |
| 161 | } |
| 162 | } |
no test coverage detected