MCPcopy Index your code
hub / github.com/tensorflow/tfjs-examples / trainOnReplayBatch

Method trainOnReplayBatch

snake-dqn/agent.js:131–161  ·  view source on GitHub ↗

Perform training on a randomly sampled batch from the replay buffer. @param {number} batchSize Batch size. @param {number} gamma Reward discount rate. Must be >= 0 and <= 1. @param {tf.train.Optimizer} optimizer The optimizer object used to update the weights of the online network.

(batchSize, gamma, optimizer)

Source from the content-addressed store, hash-verified

129 * the weights of the online network.
130 */
131 trainOnReplayBatch(batchSize, gamma, optimizer) {
132 // Get a batch of examples from the replay buffer.
133 const batch = this.replayMemory.sample(batchSize);
134 const lossFunction = () => tf.tidy(() => {
135 const stateTensor = getStateTensor(
136 batch.map(example => example[0]), this.game.height, this.game.width);
137 const actionTensor = tf.tensor1d(
138 batch.map(example => example[1]), 'int32');
139 const qs = this.onlineNetwork.apply(stateTensor, {training: true})
140 .mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1);
141
142 const rewardTensor = tf.tensor1d(batch.map(example => example[2]));
143 const nextStateTensor = getStateTensor(
144 batch.map(example => example[4]), this.game.height, this.game.width);
145 const nextMaxQTensor =
146 this.targetNetwork.predict(nextStateTensor).max(-1);
147 const doneMask = tf.scalar(1).sub(
148 tf.tensor1d(batch.map(example => example[3])).asType('float32'));
149 const targetQs =
150 rewardTensor.add(nextMaxQTensor.mul(doneMask).mul(gamma));
151 return tf.losses.meanSquaredError(targetQs, qs);
152 });
153
154 // Calculate the gradients of the loss function with repsect to the weights
155 // of the online DQN.
156 const grads = tf.variableGrads(lossFunction);
157 // Use the gradients to update the online DQN's weights.
158 optimizer.applyGradients(grads.grads);
159 tf.dispose(grads);
160 // TODO(cais): Return the loss value here?
161 }
162}

Callers 2

agent_test.jsFile · 0.80
trainFunction · 0.80

Calls 1

sampleMethod · 0.80

Tested by

no test coverage detected