MCPcopy
hub / github.com/tensorflow/tfjs-examples / createDeepQNetwork

Function createDeepQNetwork

snake-dqn/dqn.js:20–61  ·  view source on GitHub ↗
(h, w, numActions)

Source from the content-addressed store, hash-verified

18import * as tf from '@tensorflow/tfjs';
19
/**
 * Creates a deep Q-network: a convolutional model that maps an input of shape
 * [h, w, 2] to `numActions` output values (one Q-value estimate per action).
 *
 * Architecture:
 *   conv2d(128 filters, 3x3, ReLU) -> batchNormalization ->
 *   conv2d(256 filters, 3x3, ReLU) -> batchNormalization ->
 *   conv2d(256 filters, 3x3, ReLU) -> flatten ->
 *   dense(100, ReLU) -> dropout(0.25) -> dense(numActions, no activation)
 *
 * The returned model is not compiled.
 *
 * @param {number} h Height of the input. Must be an integer >= 7: each of the
 *   three unpadded 3x3 convolutions shrinks the spatial size by 2, so smaller
 *   inputs would leave an empty feature map for the `flatten` layer.
 * @param {number} w Width of the input. Same constraint as `h`.
 * @param {number} numActions Number of actions, i.e. the size of the output
 *   layer. Must be an integer greater than 1.
 * @returns {tf.Sequential} The uncompiled model.
 * @throws {Error} If `h`, `w` or `numActions` is not a valid value.
 */
export function createDeepQNetwork(h, w, numActions) {
  if (!(Number.isInteger(h) && h > 0)) {
    throw new Error(`Expected height to be a positive integer, but got ${h}`);
  }
  if (!(Number.isInteger(w) && w > 0)) {
    throw new Error(`Expected width to be a positive integer, but got ${w}`);
  }
  if (!(Number.isInteger(numActions) && numActions > 1)) {
    throw new Error(
        `Expected numActions to be an integer greater than 1, ` +
        `but got ${numActions}`);
  }

  // Every convolution below uses a 3x3 kernel, stride 1 and 'valid' (no)
  // padding, so each one shrinks height and width by (kernelSize - 1). The
  // final feature map must be at least 1x1 for `flatten` to yield a non-empty
  // vector. Keep `numConvLayers` in sync with the conv2d layers added below.
  const kernelSize = 3;
  const numConvLayers = 3;
  const minSpatialSize = numConvLayers * (kernelSize - 1) + 1;
  if (h < minSpatialSize) {
    throw new Error(
        `Expected height to be at least ${minSpatialSize}, but got ${h}`);
  }
  if (w < minSpatialSize) {
    throw new Error(
        `Expected width to be at least ${minSpatialSize}, but got ${w}`);
  }

  const model = tf.sequential();
  model.add(tf.layers.conv2d({
    filters: 128,
    kernelSize,
    strides: 1,
    padding: 'valid',
    activation: 'relu',
    inputShape: [h, w, 2],
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.conv2d({
    filters: 256,
    kernelSize,
    strides: 1,
    padding: 'valid',
    activation: 'relu',
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.conv2d({
    filters: 256,
    kernelSize,
    strides: 1,
    padding: 'valid',
    activation: 'relu',
  }));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({units: 100, activation: 'relu'}));
  model.add(tf.layers.dropout({rate: 0.25}));
  // No activation: the outputs are unbounded per-action value estimates.
  model.add(tf.layers.dense({units: numActions}));

  return model;
}
62
63/**
64 * Copy the weights from a source deep-Q network to another.

Callers 2

dqn_test.js (File) · 0.90
constructor (Method) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected