(
agent, batchSize, gamma, learningRate, cumulativeRewardThreshold,
maxNumFrames, syncEveryFrames, savePath, logDir)
| 65 | * during the training. Optional. |
| 66 | */ |
| 67 | export async function train( |
| 68 | agent, batchSize, gamma, learningRate, cumulativeRewardThreshold, |
| 69 | maxNumFrames, syncEveryFrames, savePath, logDir) { |
| 70 | let summaryWriter; |
| 71 | if (logDir != null) { |
| 72 | summaryWriter = tf.node.summaryFileWriter(logDir); |
| 73 | } |
| 74 | |
| 75 | for (let i = 0; i < agent.replayBufferSize; ++i) { |
| 76 | agent.playStep(); |
| 77 | } |
| 78 | |
| 79 | // Moving averager: cumulative reward across the 100 most recent episodes. |
| 80 | const rewardAverager100 = new MovingAverager(100); |
| 81 | // Moving averager: fruits eaten across the 100 most recent episodes. |
| 82 | const eatenAverager100 = new MovingAverager(100); |
| 83 | |
| 84 | const optimizer = tf.train.adam(learningRate); |
| 85 | let tPrev = new Date().getTime(); |
| 86 | let frameCountPrev = agent.frameCount; |
| 87 | let averageReward100Best = -Infinity; |
| 88 | while (true) { |
| 89 | agent.trainOnReplayBatch(batchSize, gamma, optimizer); |
| 90 | const {cumulativeReward, done, fruitsEaten} = agent.playStep(); |
| 91 | if (done) { |
| 92 | const t = new Date().getTime(); |
| 93 | const framesPerSecond = |
| 94 | (agent.frameCount - frameCountPrev) / (t - tPrev) * 1e3; |
| 95 | tPrev = t; |
| 96 | frameCountPrev = agent.frameCount; |
| 97 | |
| 98 | rewardAverager100.append(cumulativeReward); |
| 99 | eatenAverager100.append(fruitsEaten); |
| 100 | const averageReward100 = rewardAverager100.average(); |
| 101 | const averageEaten100 = eatenAverager100.average(); |
| 102 | |
| 103 | console.log( |
| 104 | `Frame #${agent.frameCount}: ` + |
| 105 | `cumulativeReward100=${averageReward100.toFixed(1)}; ` + |
| 106 | `eaten100=${averageEaten100.toFixed(2)} ` + |
| 107 | `(epsilon=${agent.epsilon.toFixed(3)}) ` + |
| 108 | `(${framesPerSecond.toFixed(1)} frames/s)`); |
| 109 | if (summaryWriter != null) { |
| 110 | summaryWriter.scalar( |
| 111 | 'cumulativeReward100', averageReward100, agent.frameCount); |
| 112 | summaryWriter.scalar('eaten100', averageEaten100, agent.frameCount); |
| 113 | summaryWriter.scalar('epsilon', agent.epsilon, agent.frameCount); |
| 114 | summaryWriter.scalar( |
| 115 | 'framesPerSecond', framesPerSecond, agent.frameCount); |
| 116 | } |
| 117 | if (averageReward100 >= cumulativeRewardThreshold || |
| 118 | agent.frameCount >= maxNumFrames) { |
| 119 | // TODO(cais): Save online network. |
| 120 | break; |
| 121 | } |
| 122 | if (averageReward100 > averageReward100Best) { |
| 123 | averageReward100Best = averageReward100; |
| 124 | if (savePath != null) { |
no test coverage detected