MCPcopy
hub / github.com/google/tf-quant-finance / _for_loop

Function _for_loop

tf_quant_finance/models/milstein_sampling.py:427–473  ·  view source on GitHub ↗

Sample paths using custom for_loop.

(*, dim, steps_num, current_state, drift_fn, volatility_fn,
              grad_volatility_fn, wiener_mean, watch_params, num_samples, times,
              dt, sqrt_dt, time_indices, keep_mask, random_type, seed,
              normal_draws, input_gradients, stratonovich_order,
              aux_normal_draws)

Source from the content-addressed store, hash-verified

425
426
def _for_loop(*, dim, steps_num, current_state, drift_fn, volatility_fn,
              grad_volatility_fn, wiener_mean, watch_params, num_samples, times,
              dt, sqrt_dt, time_indices, keep_mask, random_type, seed,
              normal_draws, input_gradients, stratonovich_order,
              aux_normal_draws):
  """Sample paths using custom for_loop.

  Each iteration of `custom_loops.for_loop` advances the state by one
  `_milstein_step` call. Sample recording inside the step is disabled
  (`record_samples=False`). `watch_params` is handed to the loop as its
  `params`.

  If the last dimension of `time_indices` has size 1, the loop runs for
  `steps_num` iterations. The single final state is then returned with a new
  axis inserted at position 1. Otherwise `time_indices` is used as the
  iteration numbers, and the loop output is returned transposed with
  permutation `(1, 0, 2)`.

  NOTE(review): the size-1 branch assumes the only requested index is the
  final step `steps_num`; confirm against the caller.

  Returns:
    The sampled states as a `Tensor`. Presumably shaped
    `[num_samples, num_time_points, dim]`; TODO confirm against the output
    layout of `custom_loops.for_loop`.
  """
  num_time_points = time_indices.shape.as_list()[-1]
  single_time_point = num_time_points == 1
  # One output time: just run a fixed number of steps and keep the final
  # state. Several output times: let the loop collect the states at the
  # requested iteration numbers.
  iter_nums = steps_num if single_time_point else time_indices

  # Keyword arguments to `_milstein_step` that are identical on every
  # iteration.
  static_step_kwargs = dict(
      dim=dim,
      drift_fn=drift_fn,
      volatility_fn=volatility_fn,
      grad_volatility_fn=grad_volatility_fn,
      wiener_mean=wiener_mean,
      num_samples=num_samples,
      times=times,
      dt=dt,
      sqrt_dt=sqrt_dt,
      keep_mask=keep_mask,
      random_type=random_type,
      seed=seed,
      normal_draws=normal_draws,
      input_gradients=input_gradients,
      stratonovich_order=stratonovich_order,
      aux_normal_draws=aux_normal_draws,
      record_samples=False)

  def step_fn(i, loop_state):
    # The loop carries a one-element list holding the current state.
    state = loop_state[0]
    # `written_count=0` and the expanded `result` look like placeholders,
    # since recording is off (record_samples=False). Only the third output
    # (the next state) is kept. TODO confirm in `_milstein_step`.
    _, _, advanced_state, _ = _milstein_step(
        i=i,
        written_count=0,
        current_state=state,
        result=tf.expand_dims(state, axis=1),
        **static_step_kwargs)
    return [advanced_state]

  loop_outputs = custom_loops.for_loop(
      body_fn=step_fn,
      initial_state=[current_state],
      params=watch_params,
      num_iterations=iter_nums)
  sampled = loop_outputs[0]
  if not single_time_point:
    # Swap the first two axes of the loop output. Presumably this turns
    # (time, sample, dim) into (sample, time, dim); TODO confirm.
    return tf.transpose(sampled, (1, 0, 2))
  return tf.expand_dims(sampled, axis=1)
474
475
476def _outer_prod(v1, v2):

Callers 1

_sample (function) · 0.70

Calls 2

expand_dims (method) · 0.80
transpose (method) · 0.80

Tested by

no test coverage detected