MCPcopy Index your code
hub / github.com/tensorflow/tfjs-examples / main

Function main

snake-dqn/train.js:237–264  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

235}
236
/**
 * Command-line entry point for training the snake DQN agent.
 *
 * The function works in four steps:
 * - Parses the command-line arguments.
 * - Loads the TensorFlow.js Node binding: the GPU build when `args.gpu` is
 *   truthy, the CPU build otherwise.
 * - Logs the parsed arguments as pretty-printed JSON.
 * - Constructs a SnakeGame and a SnakeGameAgent from the arguments, then
 *   awaits `train()`.
 *
 * @returns {Promise<void>} Resolves once `train()` resolves; rejects if any
 *     of the steps above throws or if `train()` rejects.
 */
async function main() {
  const args = parseArguments();

  // The binding is chosen at runtime so it can follow `args.gpu`.
  // NOTE(review): `tf` is assigned, not declared, here; presumably it is a
  // module-level `let` declared outside this view — confirm.
  tf = args.gpu ? require('@tensorflow/tfjs-node-gpu')
                : require('@tensorflow/tfjs-node');
  console.log(`args: ${JSON.stringify(args, null, 2)}`);

  const gameOptions = {
    height: args.height,
    width: args.width,
    numFruits: args.numFruits,
    initLen: args.initLen,
  };
  const game = new SnakeGame(gameOptions);

  // The same learning rate is handed both to the agent here and to train()
  // below.
  const agentOptions = {
    replayBufferSize: args.replayBufferSize,
    epsilonInit: args.epsilonInit,
    epsilonFinal: args.epsilonFinal,
    epsilonDecayFrames: args.epsilonDecayFrames,
    learningRate: args.learningRate,
  };
  const agent = new SnakeGameAgent(game, agentOptions);

  const {
    batchSize,
    gamma,
    learningRate,
    cumulativeRewardThreshold,
    maxNumFrames,
    syncEveryFrames,
    savePath,
    logDir,
  } = args;
  await train(
      agent, batchSize, gamma, learningRate, cumulativeRewardThreshold,
      maxNumFrames, syncEveryFrames, savePath, logDir);
}
265
266if (require.main === module) {
267 main();

Callers 1

train.js · File · 0.70

Calls 2

parseArguments · Function · 0.70
train · Function · 0.70

Tested by

no test coverage detected