()
| 235 | } |
| 236 | |
| 237 | async function main() { |
| 238 | const args = parseArguments(); |
| 239 | if (args.gpu) { |
| 240 | tf = require('@tensorflow/tfjs-node-gpu'); |
| 241 | } else { |
| 242 | tf = require('@tensorflow/tfjs-node'); |
| 243 | } |
| 244 | console.log(`args: ${JSON.stringify(args, null, 2)}`); |
| 245 | |
| 246 | const game = new SnakeGame({ |
| 247 | height: args.height, |
| 248 | width: args.width, |
| 249 | numFruits: args.numFruits, |
| 250 | initLen: args.initLen |
| 251 | }); |
| 252 | const agent = new SnakeGameAgent(game, { |
| 253 | replayBufferSize: args.replayBufferSize, |
| 254 | epsilonInit: args.epsilonInit, |
| 255 | epsilonFinal: args.epsilonFinal, |
| 256 | epsilonDecayFrames: args.epsilonDecayFrames, |
| 257 | learningRate: args.learningRate |
| 258 | }); |
| 259 | |
| 260 | await train( |
| 261 | agent, args.batchSize, args.gamma, args.learningRate, |
| 262 | args.cumulativeRewardThreshold, args.maxNumFrames, |
| 263 | args.syncEveryFrames, args.savePath, args.logDir); |
| 264 | } |
| 265 | |
| 266 | if (require.main === module) { |
| 267 | main(); |
no test coverage detected