(self, b_o, b_a, b_r, b_o_, b_d)
| 226 | return self.qnet(obv) |
| 227 | |
def train(self, b_o, b_a, b_r, b_o_, b_d):
    """Run one training step on a batch of transitions.

    Builds the projected target distribution over the value atoms from
    the target network's output for the next observations, then hands it
    to ``self._train_func`` together with the (row, action) indices of
    the actions taken.

    Args:
        b_o: batch of observations, passed to ``self._train_func``.
        b_a: actions taken; stacked with ``range(batch_size)`` along
            axis 1, so expected to be 1-D of length ``batch_size``.
        b_r: rewards; indexed as ``b_r[:, None]``, so a 1-D array.
        b_o_: batch of next observations, fed to ``self.targetqnet``.
        b_d: done flags; indexed as ``b_d[:, None]`` and used as
            ``1 - b_d`` to drop the bootstrap term.

    Side effects:
        Increments ``self.niter``; every ``target_q_update_freq``
        iterations syncs ``self.qnet`` into ``self.targetqnet`` and calls
        ``self.save(args.save_path)``.
    """
    # TODO: move q_estimation in tf.function
    # NOTE(review): exp() suggests targetqnet emits log-probabilities with
    # layout (batch, action, atom) -- confirm against the network definition.
    b_dist_ = np.exp(self.targetqnet(b_o_).numpy())
    # Greedy next action by expected value under the target distribution.
    b_a_ = (b_dist_ * vrange).sum(-1).argmax(1)

    # Shifted/scaled support, clipped to the representable value range.
    b_tzj = np.clip(
        reward_gamma * (1 - b_d[:, None]) * vrange[None, :] + b_r[:, None],
        min_value,
        max_value,
    )
    # Fractional atom position. The extra clip guards against float rounding
    # nudging the position just outside [0, atom_num - 1], which would index
    # out of bounds (or silently wrap around via a -1 index).
    b_i = np.clip((b_tzj - min_value) / deltaz, 0, atom_num - 1)
    b_l = np.floor(b_i).astype('int64')
    b_u = np.ceil(b_i).astype('int64')

    rows = np.arange(batch_size)
    next_dist = b_dist_[rows, b_a_, :]
    # Linear interpolation weights between the two neighbouring atoms.
    # When b_i sits exactly on an atom, floor == ceil and both weights are
    # zero, which used to discard that probability mass entirely; in that
    # case the full mass goes to the (single) atom instead.
    on_atom = b_l == b_u
    templ = np.where(on_atom, next_dist, next_dist * (b_u - b_i))
    tempu = next_dist * (b_i - b_l)

    # Unbuffered scatter-add: repeated target indices within a row accumulate,
    # exactly like the former per-element Python loop.
    b_m = np.zeros((batch_size, atom_num))
    np.add.at(b_m, (rows[:, None], b_l), templ)
    np.add.at(b_m, (rows[:, None], b_u), tempu)

    b_m = tf.convert_to_tensor(b_m, dtype='float32')
    b_index = np.stack([range(batch_size), b_a], 1)
    b_index = tf.convert_to_tensor(b_index, 'int64')

    self._train_func(b_o, b_index, b_m)

    self.niter += 1
    if self.niter % target_q_update_freq == 0:
        sync(self.qnet, self.targetqnet)
        self.save(args.save_path)
| 254 | |
| 255 | def save(self, path): |
| 256 | if path is None: |
no test coverage detected