PG class
| 57 | |
| 58 | |
| 59 | class PolicyGradient: |
| 60 | """ |
| 61 | PG class |
| 62 | """ |
| 63 | |
def __init__(self, state_dim, action_num, learning_rate=0.02, gamma=0.99):
    """
    Create the episode buffers, the policy network and its optimizer.
    :param state_dim: width of one state vector (second dim of the network input)
    :param action_num: number of units in the output layer
    :param learning_rate: value handed to the Adam optimizer
    :param gamma: stored as self.gamma (presumably the reward discount factor;
        its use is outside this method)
    """
    self.gamma = gamma

    # Three independent lists, filled one entry per step by store_transition().
    self.state_buffer = []
    self.action_buffer = []
    self.reward_buffer = []

    def dense(units, activation):
        # Both layers share the same init scheme; a fresh initializer object
        # is built for every layer.
        return tl.layers.Dense(
            n_units=units,
            act=activation,
            W_init=tf.random_normal_initializer(mean=0, stddev=0.3),
            b_init=tf.constant_initializer(0.1),
        )

    state_in = tl.layers.Input([None, state_dim], tf.float32)
    hidden = dense(30, tf.nn.tanh)(state_in)
    # act=None: the output layer emits raw scores; get_action() applies softmax.
    action_scores = dense(action_num, None)(hidden)

    self.model = tl.models.Model(inputs=state_in, outputs=action_scores)
    self.model.train()
    self.optimizer = tf.optimizers.Adam(learning_rate)
| 82 | |
def get_action(self, s, greedy=False):
    """
    Select an action for state s from the policy network output.
    :param s: state
    :param greedy: if True, return the index of the largest probability;
        otherwise delegate the choice to tl.rein.choice_action_by_probs
    :return: act
    """
    # The model is fed a batch, so wrap the single state as a batch of one.
    batch = np.array([s], np.float32)
    probs = tf.nn.softmax(self.model(batch)).numpy().ravel()
    if not greedy:
        return tl.rein.choice_action_by_probs(probs)
    return np.argmax(probs)
| 95 | |
def store_transition(self, s, a, r):
    """
    Append one step (state, action, reward) to the memory buffers.
    :param s: state, stored as a float32 array with a leading axis of size 1
    :param a: act, stored unchanged
    :param r: reward, stored unchanged
    :return: None
    """
    # Keep the extra leading axis so learn() can np.vstack the stored states.
    state_row = np.array([s], np.float32)
    self.state_buffer.append(state_row)
    self.action_buffer.append(a)
    self.reward_buffer.append(r)
| 107 | |
| 108 | def learn(self): |
| 109 | """ |
| 110 | update policy parameters via stochastic gradient ascent |
| 111 | :return: None |
| 112 | """ |
| 113 | discounted_reward_buffer_norm = self._discount_and_norm_rewards() |
| 114 | |
| 115 | with tf.GradientTape() as tape: |
| 116 | _logits = self.model(np.vstack(self.state_buffer)) |
no outgoing calls
no test coverage detected
searching dependent graphs…