Sample paths using custom for_loop.
(*, dim, steps_num, current_state, drift_fn, volatility_fn,
grad_volatility_fn, wiener_mean, watch_params, num_samples, times,
dt, sqrt_dt, time_indices, keep_mask, random_type, seed,
normal_draws, input_gradients, stratonovich_order,
aux_normal_draws)
| 425 | |
| 426 | |
def _for_loop(*, dim, steps_num, current_state, drift_fn, volatility_fn,
              grad_volatility_fn, wiener_mean, watch_params, num_samples, times,
              dt, sqrt_dt, time_indices, keep_mask, random_type, seed,
              normal_draws, input_gradients, stratonovich_order,
              aux_normal_draws):
    """Sample paths using custom for_loop.

    Repeatedly applies `_milstein_step` (with `record_samples=False`) to the
    state, driven by `custom_loops.for_loop` with `watch_params` passed as its
    `params`.

    Two modes, chosen by the static size of the last axis of `time_indices`:
      * size 1: the loop runs `steps_num` iterations and the final state is
        returned with a new axis of size 1 inserted at axis 1;
      * otherwise (including an unknown static size): `time_indices` is passed
        to `for_loop` as `num_iterations`, and the loop output has its first
        two axes swapped (perm `(1, 0, 2)`). Presumably `for_loop` stacks the
        states at those indices along a leading time axis — confirm against
        `custom_loops.for_loop`.

    All arguments are keyword-only. `steps_num`, `time_indices`,
    `watch_params` and `current_state` (the initial state) are used by this
    function itself; every other argument is forwarded unchanged to
    `_milstein_step` on each iteration.

    Returns:
      A tensor of sampled states. In the multi-point mode it is rank 3, as
      required by the `(1, 0, 2)` transpose. NOTE(review): presumably laid out
      as `[num_samples, num_time_points, dim]` — confirm with the caller.
    """
    # `as_list()` yields None for an unknown dimension, which is never == 1,
    # so unknown sizes take the multi-point path.
    num_time_points = time_indices.shape.as_list()[-1]
    single_point = num_time_points == 1
    loop_iterations = steps_num if single_point else time_indices

    # Arguments that are identical on every iteration of the loop.
    invariant_kwargs = dict(
        dim=dim,
        drift_fn=drift_fn,
        volatility_fn=volatility_fn,
        grad_volatility_fn=grad_volatility_fn,
        wiener_mean=wiener_mean,
        num_samples=num_samples,
        times=times,
        dt=dt,
        sqrt_dt=sqrt_dt,
        keep_mask=keep_mask,
        random_type=random_type,
        seed=seed,
        normal_draws=normal_draws,
        input_gradients=input_gradients,
        stratonovich_order=stratonovich_order,
        aux_normal_draws=aux_normal_draws,
    )

    def _advance_one_step(step, carried):
        # The loop carries a one-element list holding the state tensor.
        state = carried[0]
        # Samples are not recorded (record_samples=False), so `result` and
        # `written_count` are dummy values here; presumably ignored by
        # `_milstein_step` in that mode — verify.
        dummy_result = tf.expand_dims(state, axis=1)
        outputs = _milstein_step(
            i=step,
            written_count=0,
            current_state=state,
            result=dummy_result,
            record_samples=False,
            **invariant_kwargs)
        # Only the third output (the next state) is needed.
        return [outputs[2]]

    loop_output = custom_loops.for_loop(
        body_fn=_advance_one_step,
        initial_state=[current_state],
        params=watch_params,
        num_iterations=loop_iterations)
    stacked = loop_output[0]

    if single_point:
        # Add the size-1 time axis at position 1.
        return tf.expand_dims(stacked, axis=1)
    # Swap the two leading axes of the rank-3 loop output.
    return tf.transpose(stacked, (1, 0, 2))
| 474 | |
| 475 | |
| 476 | def _outer_prod(v1, v2): |
no test coverage detected