MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / _get_DQN_prediction

Method _get_DQN_prediction

examples/DeepQNetwork/DQN.py:64–99  ·  view source on GitHub ↗
(self, image)

Source from the content-addressed store, hash-verified

62 A DQN model for 2D/3D (image) observations.
63 """
def _get_DQN_prediction(self, image):
    """Build the Q-value prediction graph for an image observation.

    Args:
        image: a rank-4 or rank-5 tensor laid out as N, H, W, (C), Hist.
            A rank-5 input has its channel and history axes merged into
            a single trailing axis before the convolutions are applied.

    Returns:
        The Q-value tensor, wrapped in ``tf.identity`` so that it is
        addressable by the name 'Qvalue'. It is produced by a
        ``FullyConnected`` layer with ``self.num_actions`` units, or by
        the dueling combination of the value and advantage heads when
        ``self.method == 'Dueling'``.

    Raises:
        AssertionError: if ``image`` is not rank 4 or rank 5.
    """
    assert image.shape.rank in [4, 5], image.shape
    # image: N, H, W, (C), Hist
    if image.shape.rank == 5:
        # merge C & Hist
        image = tf.reshape(
            image,
            [-1] + list(self.state_shape[:2]) + [self.state_shape[2] * FRAME_HISTORY])

    # Rescale by 1/255 (presumably raw 8-bit pixel intensities -- confirm
    # against the caller that feeds this method).
    image = image / 255.0
    with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
        features = (LinearWrap(image)
                    # Nature architecture
                    .Conv2D('conv0', 32, 8, strides=4)
                    .Conv2D('conv1', 64, 4, strides=2)
                    .Conv2D('conv2', 64, 3)

                    # architecture used for the figure in the README, slower but takes fewer iterations to converge
                    # .Conv2D('conv0', out_channel=32, kernel_shape=5)
                    # .MaxPooling('pool0', 2)
                    # .Conv2D('conv1', out_channel=32, kernel_shape=5)
                    # .MaxPooling('pool1', 2)
                    # .Conv2D('conv2', out_channel=64, kernel_shape=4)
                    # .MaxPooling('pool2', 2)
                    # .Conv2D('conv3', out_channel=64, kernel_shape=3)

                    .FullyConnected('fc0', 512)
                    .tf.nn.leaky_relu(alpha=0.01)())
    if self.method != 'Dueling':
        Q = FullyConnected('fct', features, self.num_actions)
    else:
        # Dueling DQN: Q = A + (V - mean(A)), with the mean taken over axis 1.
        V = FullyConnected('fctV', features, 1)
        As = FullyConnected('fctA', features, self.num_actions)
        # `keepdims` replaces the deprecated `keep_dims` spelling, which
        # TF 2.x no longer accepts; keeping the reduced axis lets the mean
        # broadcast against `As`.
        Q = tf.add(As, V - tf.reduce_mean(As, axis=1, keepdims=True))
    return tf.identity(Q, name='Qvalue')
100
101
102def get_config(model):

Callers

nothing calls this directly

Calls 5

argscope — Function · 0.85
PReLU — Function · 0.85
LinearWrap — Class · 0.85
FullyConnected — Function · 0.85
add — Method · 0.45

Tested by

no test coverage detected