MCP · Copy · Index your code
hub / github.com/tensorlayer/TensorLayer / DQN

Class DQN

examples/reinforcement_learning/tutorial_DQN_variants.py:273–363  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

271
272# ############################### DQN #####################################
273class DQN(object):
274
def __init__(self):
    """Build the online Q-network and, when training, a target copy.

    All configuration comes from names resolved outside this method:
    ``qnet_type`` selects ``MLP`` (exactly the string 'MLP') or ``CNN``
    (anything else); ``args.train`` selects training vs. inference setup;
    ``lr`` and the optional ``clipnorm`` configure the Adam optimizer; and
    ``noise_scale`` seeds the parameter-noise scale that get_action adapts.
    """
    if qnet_type == 'MLP':
        net_cls = MLP
    else:
        net_cls = CNN
    self.qnet = net_cls('q')

    if not args.train:
        # Inference setup: no target network is created; weights are
        # restored through self.load(args.save_path).
        self.qnet.infer()
        self.load(args.save_path)
    else:
        self.qnet.train()
        self.targetqnet = net_cls('targetq')
        self.targetqnet.infer()
        # presumably copies the online weights into the target net —
        # confirm against sync()'s definition.
        sync(self.qnet, self.targetqnet)

    # Number of completed train() calls; read by get_action (epsilon
    # schedule, exploration window) and by train (target-sync cadence).
    self.niter = 0

    adam_kwargs = {'learning_rate': lr}
    if clipnorm is not None:
        # Forward clipnorm only when configured, otherwise Adam uses its own default.
        adam_kwargs['clipnorm'] = clipnorm
    self.optimizer = tf.optimizers.Adam(**adam_kwargs)

    # Current parameter-noise scale; get_action grows/shrinks it by 1%.
    self.noise_scale = noise_scale
292
def get_action(self, obv):
    """Choose an action index for a single observation ``obv``.

    Inference mode (``args.train`` falsy): greedy argmax of the Q-values.

    Training mode:
      * with probability ``epsilon(self.niter)`` return a random action,
        ``int(random.random() * out_dim)`` (a Python int in [0, out_dim));
      * otherwise, while ``self.niter < explore_timesteps``, return the
        argmax of Q-values computed with ``self.qnet.noise_scale`` set to
        ``self.noise_scale``, and every ``noise_update_freq`` steps of ``i``
        adapt ``self.noise_scale`` by a factor of 1.01 so that
        KL(softmax(q) || softmax(q_perturbed)) tracks the epsilon-greedy
        target ``-log(1 - eps + eps / out_dim)``;
      * after that window, return the greedy argmax.

    Non-random branches return the numpy scalar produced by ``argmax``.
    """
    # Evaluated unconditionally, even when not training.
    eps = epsilon(self.niter)

    if not args.train:
        # Add a batch axis, cast to float32 and apply the observation scale.
        batch = np.expand_dims(obv, 0).astype('float32') * ob_scale
        return self._qvalues_func(batch).numpy().argmax(1)[0]

    if random.random() < eps:
        return int(random.random() * out_dim)

    batch = np.expand_dims(obv, 0).astype('float32') * ob_scale

    if not (self.niter < explore_timesteps):
        return self._qvalues_func(batch).numpy().argmax(1)[0]

    # NOTE(review): _qvalues_func is a @tf.function; if the model reads
    # noise_scale as a plain Python attribute, its value may be captured at
    # trace time, so toggling it between calls may have no effect — confirm.
    self.qnet.noise_scale = self.noise_scale
    q_perturbed = self._qvalues_func(batch).numpy()
    self.qnet.noise_scale = 0

    # NOTE(review): `i` is neither a parameter nor an attribute; it is
    # resolved from an enclosing/global scope (presumably the training
    # loop's step counter) — calling this outside that loop would fail.
    if i % noise_update_freq == 0:
        q_clean = self._qvalues_func(batch).numpy()
        log_ratio = (log_softmax(q_clean, 1) - log_softmax(q_perturbed, 1))
        kl_perturbed = np.sum(log_ratio * softmax(q_clean, 1), 1).mean()
        # KL between a greedy policy and its epsilon-greedy counterpart.
        kl_target = -np.log(1 - eps + eps / out_dim)
        if kl_perturbed < kl_target:
            self.noise_scale *= 1.01
        else:
            self.noise_scale /= 1.01

    return q_perturbed.argmax(1)[0]
318
@tf.function
def _qvalues_func(self, obv):
    """Run ``obv`` through the online Q-network inside a ``tf.function``.

    NOTE(review): because this method is traced, any Python-level state
    the network reads during its forward pass (for example a plain
    ``noise_scale`` attribute) may be frozen into the traced graph —
    confirm how MLP/CNN consume it.
    """
    online_net = self.qnet
    return online_net(obv)
322
def train(self, b_o, b_a, b_r, b_o_, b_d):
    """Apply one optimisation step, then periodically sync and checkpoint.

    The five batch arguments are forwarded unchanged to
    ``self._train_func`` (judging by the names, presumably observations,
    actions, rewards, next observations and done flags — confirm against
    the caller). After the step ``self.niter`` is incremented. Whenever it
    reaches a multiple of ``target_q_update_freq``, ``sync(self.qnet,
    self.targetqnet)`` is called and then ``self.save(args.save_path)``.
    """
    self._train_func(b_o, b_a, b_r, b_o_, b_d)
    self.niter += 1

    if self.niter % target_q_update_freq != 0:
        return
    sync(self.qnet, self.targetqnet)
    self.save(args.save_path)
330

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected