(self, image)
| 62 | A DQN model for 2D/3D (image) observations. |
| 63 | """ |
def _get_DQN_prediction(self, image):
    """Build the Q-value output for a batch of image observations.

    The input is divided by 255.0, passed through three ``Conv2D`` layers
    (each using a PReLU activation via ``argscope``), a 512-unit
    ``FullyConnected`` layer and a leaky ReLU. The final head depends on
    ``self.method``:

    * anything other than ``'Dueling'``: one ``FullyConnected`` layer with
      ``self.num_actions`` outputs;
    * ``'Dueling'``: a 1-unit value stream ``V`` and a
      ``self.num_actions``-unit advantage stream ``A`` combined as
      ``Q = A + (V - mean(A, axis=1))``.

    Args:
        image: a rank-4 or rank-5 tensor. Per the inline note the layout is
            ``N, H, W, (C), Hist``; a rank-5 input is reshaped to rank 4 by
            folding the last two axes into
            ``self.state_shape[2] * FRAME_HISTORY`` channels.

    Returns:
        The Q-value tensor, wrapped in ``tf.identity`` with name
        ``'Qvalue'`` (presumably shaped ``(N, self.num_actions)`` --
        confirm against ``FullyConnected``).

    Raises:
        AssertionError: if ``image`` is not rank 4 or rank 5.
    """
    assert image.shape.rank in [4, 5], image.shape
    # image: N, H, W, (C), Hist
    if image.shape.rank == 5:
        # merge C & Hist
        image = tf.reshape(
            image,
            [-1] + list(self.state_shape[:2]) + [self.state_shape[2] * FRAME_HISTORY])

    # NOTE(review): presumably maps 8-bit pixel values into [0, 1] -- confirm
    # the dtype/range the caller feeds in.
    image = image / 255.0
    with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
        features = (LinearWrap(image)
                    # Nature architecture
                    .Conv2D('conv0', 32, 8, strides=4)
                    .Conv2D('conv1', 64, 4, strides=2)
                    .Conv2D('conv2', 64, 3)

                    # architecture used for the figure in the README, slower but takes fewer iterations to converge
                    # .Conv2D('conv0', out_channel=32, kernel_shape=5)
                    # .MaxPooling('pool0', 2)
                    # .Conv2D('conv1', out_channel=32, kernel_shape=5)
                    # .MaxPooling('pool1', 2)
                    # .Conv2D('conv2', out_channel=64, kernel_shape=4)
                    # .MaxPooling('pool2', 2)
                    # .Conv2D('conv3', out_channel=64, kernel_shape=3)

                    .FullyConnected('fc0', 512)
                    # The trailing () ends the LinearWrap chain; presumably it
                    # returns the underlying tensor -- see LinearWrap docs.
                    .tf.nn.leaky_relu(alpha=0.01)())
    if self.method != 'Dueling':
        Q = FullyConnected('fct', features, self.num_actions)
    else:
        # Dueling DQN
        V = FullyConnected('fctV', features, 1)
        As = FullyConnected('fctA', features, self.num_actions)
        # `keepdims` replaces the deprecated `keep_dims` keyword; keeping the
        # reduced axis lets the mean broadcast against As and V.
        Q = tf.add(As, V - tf.reduce_mean(As, axis=1, keepdims=True))
    return tf.identity(Q, name='Qvalue')
| 100 | |
| 101 | |
| 102 | def get_config(model): |
nothing calls this directly
no test coverage detected