MCPcopy
hub / github.com/tensorlayer/TensorLayer / PPO

Class PPO

examples/reinforcement_learning/tutorial_PPO.py:70–253  ·  view source on GitHub ↗

PPO class

Source from the content-addressed store, hash-verified

68
69
70class PPO(object):
71 """
72 PPO class
73 """
def __init__(self, state_dim, action_dim, action_bound, method='clip'):
    """
    Build the critic and actor networks, their optimisers and the (empty) buffers.

    :param state_dim: length of the state vector fed to both networks
    :param action_dim: number of action components produced by the actor
    :param action_bound: scale applied to the actor's tanh output to form the action mean
    :param method: PPO variant, either 'clip' (clipped surrogate) or 'penalty' (KL penalty)
    :raises ValueError: if ``method`` is neither 'clip' nor 'penalty'
    """
    # Validate before building anything: an unrecognised method would leave
    # neither the clip hyper-parameter (epsilon) nor the penalty ones
    # (kl_target, lam) set, so the object could only fail later, in training.
    if method not in ('clip', 'penalty'):
        raise ValueError("method must be 'clip' or 'penalty', got {!r}".format(method))

    # critic: state -> single scalar output (value estimate), two 64-unit ReLU layers
    with tf.name_scope('critic'):
        inputs = tl.layers.Input([None, state_dim], tf.float32, 'state')
        layer = tl.layers.Dense(64, tf.nn.relu)(inputs)
        layer = tl.layers.Dense(64, tf.nn.relu)(layer)
        v = tl.layers.Dense(1)(layer)
    self.critic = tl.models.Model(inputs, v)
    self.critic.train()

    # actor: state -> action mean; tanh keeps the raw output in [-1, 1] and the
    # Lambda layer rescales it by action_bound.
    with tf.name_scope('actor'):
        inputs = tl.layers.Input([None, state_dim], tf.float32, 'state')
        layer = tl.layers.Dense(64, tf.nn.relu)(inputs)
        layer = tl.layers.Dense(64, tf.nn.relu)(layer)
        a = tl.layers.Dense(action_dim, tf.nn.tanh)(layer)
        mean = tl.layers.Lambda(lambda x: x * action_bound, name='lambda')(a)
        # State-independent log standard deviation, one entry per action
        # component; zeros means the initial std is exp(0) = 1.
        logstd = tf.Variable(np.zeros(action_dim, dtype=np.float32))
    self.actor = tl.models.Model(inputs, mean)
    # NOTE(review): logstd is registered as trainable by appending to
    # trainable_weights. That only takes effect if the property returns the
    # model's own list rather than a fresh copy. Confirm for the installed
    # TensorLayer version, otherwise logstd is never optimised.
    self.actor.trainable_weights.append(logstd)
    self.actor.logstd = logstd
    self.actor.train()

    # LR_A / LR_C are module-level hyper-parameters defined elsewhere in the file.
    self.actor_opt = tf.optimizers.Adam(LR_A)
    self.critic_opt = tf.optimizers.Adam(LR_C)

    self.method = method
    if method == 'penalty':
        # adaptive-KL-penalty hyper-parameters (module-level constants)
        self.kl_target = KL_TARGET
        self.lam = LAM
    elif method == 'clip':
        # clipping range for the probability ratio (module-level constant)
        self.epsilon = EPSILON

    # per-trajectory buffers, all starting empty
    self.state_buffer, self.action_buffer = [], []
    self.reward_buffer, self.cumulative_reward_buffer = [], []
    self.action_bound = action_bound
110
111 def train_actor(self, state, action, adv, old_pi):
112 """
113 Update policy network
114 :param state: state batch
115 :param action: action batch
116 :param adv: advantage batch
117 :param old_pi: old pi distribution
118 :return: kl_mean or None
119 """
120 with tf.GradientTape() as tape:
121 mean, std = self.actor(state), tf.exp(self.actor.logstd)
122 pi = tfp.distributions.Normal(mean, std)
123
124 ratio = tf.exp(pi.log_prob(action) - old_pi.log_prob(action))
125 surr = ratio * adv
126 if self.method == 'penalty': # ppo penalty
127 kl = tfp.distributions.kl_divergence(old_pi, pi)

Callers 1

tutorial_PPO.py · File · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected