MCP · copy
hub / github.com/tensorlayer/TensorLayer / PolicyGradient

Class PolicyGradient

examples/reinforcement_learning/tutorial_PG.py:59–161  ·  view source on GitHub ↗

PG class

Source from the content-addressed store, hash-verified

57
58
59class PolicyGradient:
60 """
61 PG class
62 """
63
def __init__(self, state_dim, action_num, learning_rate=0.02, gamma=0.99):
    """Build the policy network, its optimizer and empty rollout buffers.

    The network maps a batch of states of shape ``[None, state_dim]`` through
    one 30-unit tanh layer to ``action_num`` linear outputs (unnormalised
    logits; softmax is applied later when choosing an action).

    :param state_dim: length of a single state vector fed to the network
    :param action_num: number of discrete actions (width of the output layer)
    :param learning_rate: step size handed to the Adam optimizer
    :param gamma: stored as ``self.gamma`` (assumed to be the return
        discount factor used during learning — confirm against ``learn``)
    """
    self.gamma = gamma

    # Rollout memory; store_transition() appends one entry to each per step.
    self.state_buffer = []
    self.action_buffer = []
    self.reward_buffer = []

    def _make_dense(units, activation):
        # Both layers use the same initialisation: N(0, 0.3) weights and a
        # constant 0.1 bias. Fresh initializer objects are built per layer,
        # exactly as the layers were configured individually before.
        return tl.layers.Dense(
            n_units=units,
            act=activation,
            W_init=tf.random_normal_initializer(mean=0, stddev=0.3),
            b_init=tf.constant_initializer(0.1),
        )

    # Layers are created in the same order as before (hidden, then output)
    # so random initial values and auto-generated layer names match.
    states_in = tl.layers.Input([None, state_dim], tf.float32)
    hidden = _make_dense(30, tf.nn.tanh)(states_in)
    logits = _make_dense(action_num, None)(hidden)

    self.model = tl.models.Model(inputs=states_in, outputs=logits)
    self.model.train()
    self.optimizer = tf.optimizers.Adam(learning_rate)
82
def get_action(self, s, greedy=False):
    """Choose an action for a single state.

    The state is wrapped into a batch of one float32 row, passed through
    ``self.model``, and the resulting logits are turned into probabilities
    with a softmax.

    :param s: a single state (one row of network input)
    :param greedy: when True, return the most probable action; otherwise
        sample one via ``tl.rein.choice_action_by_probs``
    :return: the chosen action (``np.argmax`` index in the greedy case)
    """
    single_batch = np.array([s], np.float32)
    action_probs = tf.nn.softmax(self.model(single_batch)).numpy().ravel()
    if not greedy:
        # Stochastic policy: sample proportionally to the probabilities.
        return tl.rein.choice_action_by_probs(action_probs)
    return np.argmax(action_probs)
95
def store_transition(self, s, a, r):
    """Record one (state, action, reward) step in the rollout buffers.

    The state is kept as a float32 array with a leading axis of length 1,
    so the collected states can later be stacked row-wise with ``np.vstack``.

    :param s: state observed at this step
    :param a: action taken in that state
    :param r: reward received for that action
    :return: None
    """
    state_row = np.array([s], dtype=np.float32)
    # In-place list extension: the buffer objects themselves are mutated.
    self.state_buffer += [state_row]
    self.action_buffer += [a]
    self.reward_buffer += [r]
107
108 def learn(self):
109 """
110 update policy parameters via stochastic gradient ascent
111 :return: None
112 """
113 discounted_reward_buffer_norm = self._discount_and_norm_rewards()
114
115 with tf.GradientTape() as tape:
116 _logits = self.model(np.vstack(self.state_buffer))

Callers 1

tutorial_PG.py · File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…