()
| 132 | // Get a batch of examples from the replay buffer. |
| 133 | const batch = this.replayMemory.sample(batchSize); |
| 134 | const lossFunction = () => tf.tidy(() => { |
| 135 | const stateTensor = getStateTensor( |
| 136 | batch.map(example => example[0]), this.game.height, this.game.width); |
| 137 | const actionTensor = tf.tensor1d( |
| 138 | batch.map(example => example[1]), 'int32'); |
| 139 | const qs = this.onlineNetwork.apply(stateTensor, {training: true}) |
| 140 | .mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1); |
| 141 | |
| 142 | const rewardTensor = tf.tensor1d(batch.map(example => example[2])); |
| 143 | const nextStateTensor = getStateTensor( |
| 144 | batch.map(example => example[4]), this.game.height, this.game.width); |
| 145 | const nextMaxQTensor = |
| 146 | this.targetNetwork.predict(nextStateTensor).max(-1); |
| 147 | const doneMask = tf.scalar(1).sub( |
| 148 | tf.tensor1d(batch.map(example => example[3])).asType('float32')); |
| 149 | const targetQs = |
| 150 | rewardTensor.add(nextMaxQTensor.mul(doneMask).mul(gamma)); |
| 151 | return tf.losses.meanSquaredError(targetQs, qs); |
| 152 | }); |
| 153 | |
| 154 | // Calculate the gradients of the loss function with respect to the weights |
| 155 | // of the online DQN. |
nothing calls this directly
no test coverage detected