(h, w, numActions)
| 18 | import * as tf from '@tensorflow/tfjs'; |
| 19 | |
/**
 * Create a deep Q-network: a convolutional model mapping an `[h, w, 2]`
 * input tensor to one output value per action.
 *
 * Layers, in order:
 *   - conv2d, 128 filters, 3x3 kernel, stride 1, ReLU (input `[h, w, 2]`)
 *   - batch normalization
 *   - conv2d, 256 filters, 3x3 kernel, stride 1, ReLU
 *   - batch normalization
 *   - conv2d, 256 filters, 3x3 kernel, stride 1, ReLU
 *   - flatten
 *   - dense, 100 units, ReLU
 *   - dropout, rate 0.25
 *   - dense, `numActions` units, no activation specified
 *
 * NOTE(review): no `padding` is passed to the conv layers, so each one uses
 * the tfjs default. If that default is 'valid', every 3x3 conv trims 2 rows
 * and 2 columns, and `h`/`w` need to be at least 7. This function does not
 * check that; confirm what the callers pass.
 *
 * The model is returned uncompiled; no `compile()` call is made here.
 *
 * @param {number} h Height of the input grid; must be a positive integer.
 * @param {number} w Width of the input grid; must be a positive integer.
 * @param {number} numActions Number of output units (one per action); must be
 *   an integer greater than 1.
 * @returns {tf.Sequential} The uncompiled sequential model.
 * @throws {Error} If `h` or `w` is not a positive integer, or `numActions` is
 *   not an integer greater than 1.
 */
export function createDeepQNetwork(h, w, numActions) {
  if (!(Number.isInteger(h) && h > 0)) {
    throw new Error(`Expected height to be a positive integer, but got ${h}`);
  }
  if (!(Number.isInteger(w) && w > 0)) {
    throw new Error(`Expected width to be a positive integer, but got ${w}`);
  }
  if (!(Number.isInteger(numActions) && numActions > 1)) {
    throw new Error(
        `Expected numActions to be an integer greater than 1, ` +
        `but got ${numActions}`);
  }

  const model = tf.sequential();
  model.add(tf.layers.conv2d({
    filters: 128,
    kernelSize: 3,
    strides: 1,
    activation: 'relu',
    // Two input channels per grid cell.
    inputShape: [h, w, 2],
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.conv2d({
    filters: 256,
    kernelSize: 3,
    strides: 1,
    activation: 'relu',
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.conv2d({
    filters: 256,
    kernelSize: 3,
    strides: 1,
    activation: 'relu',
  }));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({units: 100, activation: 'relu'}));
  model.add(tf.layers.dropout({rate: 0.25}));
  // One output per action; no activation is set, so the values are not
  // squashed into any fixed range.
  model.add(tf.layers.dense({units: numActions}));

  return model;
}
| 62 | |
| 63 | /** |
| 64 | * Copy the weights from a source deep-Q network to another. |
no outgoing calls
no test coverage detected