| 195 | class DQN(object): |
| 196 | |
def __init__(self):
    """Create the Q-network(s), the iteration counter and the optimizer.

    When ``args.train`` is truthy, ``train()`` is called on the online
    network, a second network named ``'targetq'`` is created, ``infer()``
    is called on it, and both networks are handed to ``sync``. Otherwise
    ``infer()`` is called on the online network and
    ``self.load(args.save_path)`` is invoked.

    NOTE(review): the nesting here is reconstructed from a listing whose
    indentation was lost; ``self.niter`` and the optimizer are assumed to
    be set up unconditionally after the train/infer branch -- confirm
    against the original file.
    """
    # The network class is chosen by the module-level ``qnet_type`` setting.
    network_cls = MLP if qnet_type == 'MLP' else CNN
    self.qnet = network_cls('q')

    if not args.train:
        self.qnet.infer()
        self.load(args.save_path)
    else:
        self.qnet.train()
        self.targetqnet = network_cls('targetq')
        self.targetqnet.infer()
        sync(self.qnet, self.targetqnet)

    self.niter = 0

    # Forward ``clipnorm`` to Adam only when it is configured, so the
    # optimizer never receives an explicit ``clipnorm=None``.
    optimizer_kwargs = {'learning_rate': lr}
    if clipnorm is not None:
        optimizer_kwargs['clipnorm'] = clipnorm
    self.optimizer = tf.optimizers.Adam(**optimizer_kwargs)
| 213 | |
| 214 | def get_action(self, obv): |
| 215 | eps = epsilon(self.niter) |