MCPcopy
hub / github.com/tensorlayer/TensorLayer / train

Method train

examples/reinforcement_learning/tutorial_C51.py:228–253  ·  view source on GitHub ↗
(self, b_o, b_a, b_r, b_o_, b_d)

Source from the content-addressed store, hash-verified

226 return self.qnet(obv)
227
def train(self, b_o, b_a, b_r, b_o_, b_d):
    """Run one C51 (categorical DQN) update on a sampled mini-batch.

    Builds the projected target distribution ``b_m`` from the target network
    (categorical projection, Bellemare et al. 2017, Algorithm 1). Passes it to
    ``self._train_func`` together with the (row, action) index of each taken
    action. Every ``target_q_update_freq`` iterations it syncs the target
    network and saves the model to ``args.save_path``.

    Args:
        b_o: batch of observations, forwarded to ``self._train_func``.
        b_a: actions taken, one per sample.
        b_r: rewards as a numpy array, one per sample.
        b_o_: batch of next observations, fed to ``self.targetqnet``.
        b_d: done flags as a numpy array (1 = terminal), one per sample.
    """
    # TODO: move q_estimation in tf.function
    # Assumes targetqnet returns per-action log-probabilities over the atoms,
    # shaped (batch, n_actions, atom_num), hence np.exp. TODO confirm.
    b_dist_ = np.exp(self.targetqnet(b_o_).numpy())
    # Derive the row count from the actual batch rather than the global
    # batch_size, so a short final batch does not break indexing.
    n_batch = b_dist_.shape[0]
    rows = np.arange(n_batch, dtype='int64')
    # Greedy next action by expected value sum_z z * p(z).
    b_a_ = (b_dist_ * vrange).sum(-1).argmax(1)
    # Bellman-update every atom; terminal transitions keep only the reward.
    b_tzj = np.clip(reward_gamma * (1 - b_d[:, None]) * vrange[None, :] + b_r[:, None], min_value, max_value)
    # Fractional atom position of each updated atom. Clip again so float
    # rounding in the division cannot push ceil() past the last atom index.
    b_i = np.clip((b_tzj - min_value) / deltaz, 0, atom_num - 1)
    b_l = np.floor(b_i).astype('int64')
    b_u = np.ceil(b_i).astype('int64')
    # If b_i lands exactly on an atom (e.g. any target clipped to min_value),
    # then b_l == b_u and both linear weights would be 0, silently dropping
    # that probability mass. In that case give the whole mass to b_l.
    exact = b_l == b_u
    p_next = b_dist_[rows, b_a_, :]
    templ = p_next * (b_u - b_i + exact)
    tempu = p_next * (b_i - b_l)
    b_m = np.zeros((n_batch, atom_num))
    # Scatter-add the split mass onto the neighbouring atoms. np.add.at is
    # unbuffered, so repeated target indices within a row accumulate
    # (unlike b_m[idx] += ...).
    np.add.at(b_m, (rows[:, None], b_l), templ)
    np.add.at(b_m, (rows[:, None], b_u), tempu)
    b_m = tf.convert_to_tensor(b_m, dtype='float32')
    b_index = np.stack([rows, b_a], 1)
    b_index = tf.convert_to_tensor(b_index, 'int64')

    self._train_func(b_o, b_index, b_m)

    self.niter += 1
    if self.niter % target_q_update_freq == 0:
        sync(self.qnet, self.targetqnet)
        self.save(args.save_path)
254
255 def save(self, path):
256 if path is None:

Callers 15

mainFunction · 0.45
__init__Method · 0.45
__init__Method · 0.45
__init__Method · 0.45
__init__Method · 0.45
__init__Method · 0.45

Calls 4

_train_funcMethod · 0.95
saveMethod · 0.95
sumMethod · 0.80
syncFunction · 0.70

Tested by

no test coverage detected