MCPcopy
hub / github.com/tensorflow/tfjs-examples / lossFunction

Method lossFunction

snake-dqn/agent.js:134–152  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

132 // Get a batch of examples from the replay buffer.
133 const batch = this.replayMemory.sample(batchSize);
// Zero-argument closure whose result is differentiated by the surrounding
// method (see the "Calculate the gradients of the loss function" comment that
// follows it). It builds the mean-squared-error loss between the online
// network's Q-value for each taken action and a one-step Bellman target
// computed with the target network, over the sampled `batch`.
//
// Everything runs inside tf.tidy, so all intermediate tensors are disposed and
// only the returned loss tensor survives.
//
// Replay-example fields used here: [0] state, [1] action, [2] reward,
// [3] done flag, [4] next state.
const lossFunction = () => tf.tidy(() => {
  // Collect field `index` from every example in the batch.
  const column = (index) => batch.map((example) => example[index]);
  // Encode a column of game states as a state tensor for the current board size.
  const toStateTensor = (index) =>
      getStateTensor(column(index), this.game.height, this.game.width);

  // Online-network Q-value of the action actually taken: the one-hot mask
  // zeroes every other action's value and the sum over the last axis keeps
  // the remaining one.
  const currentStates = toStateTensor(0);
  const actionIndices = tf.tensor1d(column(1), 'int32');
  const onlineQs = this.onlineNetwork.apply(currentStates, {training: true});
  const actionMask = tf.oneHot(actionIndices, NUM_ACTIONS);
  const predictedQ = onlineQs.mul(actionMask).sum(-1);

  // Bellman target: reward + gamma * max_a' Q_target(next_state, a'),
  // where the bootstrap term is masked out for finished transitions.
  const rewards = tf.tensor1d(column(2));
  const nextStates = toStateTensor(4);
  const bestNextQ = this.targetNetwork.predict(nextStates).max(-1);
  // 1 - done: presumably example[3] is a boolean / 0-1 done flag, so this is
  // 1 for non-terminal and 0 for terminal transitions — confirm with the
  // replay-memory producer.
  const notDone = tf.scalar(1).sub(tf.tensor1d(column(3)).asType('float32'));
  const targetQ = rewards.add(bestNextQ.mul(notDone).mul(gamma));

  return tf.losses.meanSquaredError(targetQ, predictedQ);
});
153
154 // Calculate the gradients of the loss function with respect to the weights
155 // of the online DQN.

Callers

nothing calls this directly

Calls 3

getStateTensor (function) · 0.90
apply (method) · 0.45
predict (method) · 0.45

Tested by

no test coverage detected