| 271 | |
| 272 | # ############################### DQN ##################################### |
| 273 | class DQN(object): |
| 274 | |
def __init__(self):
    """Build the online Q-network, the target network (training only) and the optimizer.

    Configuration is read from module-level names: ``qnet_type`` picks the
    ``MLP`` or ``CNN`` architecture, ``args.train`` selects training versus
    inference mode, and ``lr`` / ``clipnorm`` / ``noise_scale`` configure the
    Adam optimizer and the initial parameter-noise scale.
    """
    net_cls = CNN if qnet_type != 'MLP' else MLP
    self.qnet = net_cls('q')
    if not args.train:
        # Inference only: no target network is created; weights are restored
        # through self.load(args.save_path).
        self.qnet.infer()
        self.load(args.save_path)
    else:
        self.qnet.train()
        self.targetqnet = net_cls('targetq')
        self.targetqnet.infer()
        # sync() presumably copies the online weights into the target net.
        sync(self.qnet, self.targetqnet)
    self.niter = 0  # number of completed train() calls
    adam_kwargs = {'learning_rate': lr}
    # Only forward clipnorm when it is configured; None is never passed.
    if clipnorm is not None:
        adam_kwargs['clipnorm'] = clipnorm
    self.optimizer = tf.optimizers.Adam(**adam_kwargs)
    self.noise_scale = noise_scale
| 292 | |
def get_action(self, obv, step=None):
    """Select an action for a single observation.

    Behaviour depends on the module-level ``args.train`` flag:

    * Training: with probability ``epsilon(self.niter)`` a random action in
      ``[0, out_dim)`` is returned. Otherwise, while
      ``self.niter < explore_timesteps``, the action is the argmax of the
      Q-values computed with parameter noise enabled on ``self.qnet``.
      When the step counter is a multiple of ``noise_update_freq``,
      ``self.noise_scale`` is adapted: it is multiplied by 1.01 if the KL
      divergence between the clean and perturbed softmax policies is below
      ``-log(1 - eps + eps / out_dim)``, and divided by 1.01 otherwise.
      Once ``self.niter >= explore_timesteps``, the greedy action is returned.
    * Inference: the greedy action of ``self.qnet`` is returned.

    Args:
        obv: a single observation. A batch axis is added, it is cast to
            float32 and it is scaled by ``ob_scale`` before the forward pass.
        step: counter that gates the noise-scale adaptation. When ``None``
            (the default) the module-level ``i`` is used, exactly as before.
            NOTE(review): ``i`` is presumably the environment-step counter of
            the outer training loop; confirm, and prefer passing ``step``.

    Returns:
        int: the chosen action index, always a plain Python ``int``.
    """
    eps = epsilon(self.niter)
    training = args.train
    # Epsilon-greedy exploration (training only).
    if training and random.random() < eps:
        return int(random.random() * out_dim)

    # Add a batch axis and apply the observation scaling used by the network.
    obv = np.expand_dims(obv, 0).astype('float32') * ob_scale

    if not training or self.niter >= explore_timesteps:
        return int(self._qvalues_func(obv).numpy().argmax(1)[0])

    # Forward pass with parameter noise switched on. The try/finally restores
    # the online network to noise-free mode even if the forward pass raises.
    # NOTE(review): _qvalues_func is a tf.function. If the network reads
    # noise_scale as a plain Python attribute while tracing, changes made here
    # may not take effect after the first trace; confirm against the model.
    self.qnet.noise_scale = self.noise_scale
    try:
        q_ptb = self._qvalues_func(obv).numpy()
    finally:
        self.qnet.noise_scale = 0

    # The global `i` is only read here, matching the original lazy lookup.
    cur_step = i if step is None else step
    if cur_step % noise_update_freq == 0:
        q = self._qvalues_func(obv).numpy()
        # KL(softmax(q) || softmax(q_ptb)), averaged over the batch axis.
        kl_ptb = np.sum((log_softmax(q, 1) - log_softmax(q_ptb, 1)) * softmax(q, 1), 1).mean()
        # KL(greedy || eps-greedy): the greedy action keeps mass
        # 1 - eps + eps / out_dim under eps-greedy, giving the target distance.
        kl_explore = -np.log(1 - eps + eps / out_dim)
        if kl_ptb < kl_explore:
            self.noise_scale *= 1.01  # perturbation too weak: grow it
        else:
            self.noise_scale /= 1.01  # perturbation too strong: shrink it
    return int(q_ptb.argmax(1)[0])
| 318 | |
@tf.function
def _qvalues_func(self, obv):
    """Return the online Q-network's output for the (batched) observation ``obv``.

    Wrapped in ``tf.function`` so the forward pass is traced into a graph.
    NOTE(review): attributes of ``self.qnet`` that are read as plain Python
    values during tracing (e.g. ``noise_scale``, toggled by ``get_action``) are
    captured at trace time; confirm the network applies its noise in a
    graph-safe way.
    """
    q_values = self.qnet(obv)
    return q_values
| 322 | |
def train(self, b_o, b_a, b_r, b_o_, b_d):
    """Run one optimisation step, then periodically sync the target net and save.

    The batch arguments are forwarded unchanged to ``self._train_func``.
    Presumably they are observations, actions, rewards, next observations and
    done flags, going by the ``b_*`` names; TODO confirm. After every
    ``target_q_update_freq``-th call, the target network is synced from the
    online network and the model is saved to ``args.save_path``.
    """
    self._train_func(b_o, b_a, b_r, b_o_, b_d)
    self.niter += 1
    if self.niter % target_q_update_freq:
        return  # not a target-update step
    sync(self.qnet, self.targetqnet)
    self.save(args.save_path)
| 330 |
no outgoing calls
no test coverage detected