MCPcopy
hub / github.com/tensorflow/tfjs-examples / train

Function train

snake-dqn/train.js:67–138  ·  view source on GitHub ↗
(
    agent, batchSize, gamma, learningRate, cumulativeRewardThreshold,
    maxNumFrames, syncEveryFrames, savePath, logDir)

Source from the content-addressed store, hash-verified

65 * during the training. Optional.
66 */
67export async function train(
68 agent, batchSize, gamma, learningRate, cumulativeRewardThreshold,
69 maxNumFrames, syncEveryFrames, savePath, logDir) {
70 let summaryWriter;
71 if (logDir != null) {
72 summaryWriter = tf.node.summaryFileWriter(logDir);
73 }
74
75 for (let i = 0; i < agent.replayBufferSize; ++i) {
76 agent.playStep();
77 }
78
79 // Moving averager: cumulative reward across the 100 most recent episodes.
80 const rewardAverager100 = new MovingAverager(100);
81 // Moving averager: fruits eaten across the 100 most recent episodes.
82 const eatenAverager100 = new MovingAverager(100);
83
84 const optimizer = tf.train.adam(learningRate);
85 let tPrev = new Date().getTime();
86 let frameCountPrev = agent.frameCount;
87 let averageReward100Best = -Infinity;
88 while (true) {
89 agent.trainOnReplayBatch(batchSize, gamma, optimizer);
90 const {cumulativeReward, done, fruitsEaten} = agent.playStep();
91 if (done) {
92 const t = new Date().getTime();
93 const framesPerSecond =
94 (agent.frameCount - frameCountPrev) / (t - tPrev) * 1e3;
95 tPrev = t;
96 frameCountPrev = agent.frameCount;
97
98 rewardAverager100.append(cumulativeReward);
99 eatenAverager100.append(fruitsEaten);
100 const averageReward100 = rewardAverager100.average();
101 const averageEaten100 = eatenAverager100.average();
102
103 console.log(
104 `Frame #${agent.frameCount}: ` +
105 `cumulativeReward100=${averageReward100.toFixed(1)}; ` +
106 `eaten100=${averageEaten100.toFixed(2)} ` +
107 `(epsilon=${agent.epsilon.toFixed(3)}) ` +
108 `(${framesPerSecond.toFixed(1)} frames/s)`);
109 if (summaryWriter != null) {
110 summaryWriter.scalar(
111 'cumulativeReward100', averageReward100, agent.frameCount);
112 summaryWriter.scalar('eaten100', averageEaten100, agent.frameCount);
113 summaryWriter.scalar('epsilon', agent.epsilon, agent.frameCount);
114 summaryWriter.scalar(
115 'framesPerSecond', framesPerSecond, agent.frameCount);
116 }
117 if (averageReward100 >= cumulativeRewardThreshold ||
118 agent.frameCount >= maxNumFrames) {
119 // TODO(cais): Save online network.
120 break;
121 }
122 if (averageReward100 > averageReward100Best) {
123 averageReward100Best = averageReward100;
124 if (savePath != null) {

Callers 1

mainFunction · 0.70

Calls 6

appendMethod · 0.95
averageMethod · 0.95
copyWeightsFunction · 0.90
playStepMethod · 0.80
getTimeMethod · 0.80
trainOnReplayBatchMethod · 0.80

Tested by

no test coverage detected