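Returns a sampling function with the given ODE settings. Args: - sampling_method: type of sampler used to solve the ODE; defaults to Dopri5 - num_steps: - fixed-step solvers (Euler, Heun): the actual number of integration steps performed - adaptive solvers (Dopri5): the number of datapoints saved during integration, produced by interpolation - atol: absolute error tolerance for the solver - rtol: relative error tolerance for the solver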
(
self,
*,
sampling_method="dopri5",
num_steps=50,
atol=1e-6,
rtol=1e-3,
reverse=False,
do_shift=False,
time_shifting_factor=None,
)
| 398 | |
| 399 | |
def sample_ode(
    self,
    *,
    sampling_method="dopri5",
    num_steps=50,
    atol=1e-6,
    rtol=1e-3,
    reverse=False,
    do_shift=False,
    time_shifting_factor=None,
):
    """Return a sampling function with the given ODE settings.

    Args:
        sampling_method: type of sampler used to solve the ODE; defaults to
            "dopri5".
        num_steps:
            - fixed-step solvers (Euler, Heun): the actual number of
              integration steps performed.
            - adaptive solvers (Dopri5): the number of datapoints saved
              during integration, produced by interpolation.
        atol: absolute error tolerance for the solver.
        rtol: relative error tolerance for the solver.
        reverse: forwarded to ``self.transport.check_interval`` when computing
            the integration interval ``(t0, t1)``.
        do_shift: forwarded unchanged to ``ode``.
        time_shifting_factor: forwarded unchanged to ``ode``.

    Returns:
        The bound ``sample`` method of the constructed ``ode`` object.
    """

    # Original note: "for flux".
    # Wrap self.drift in a closure instead of binding it directly so that the
    # attribute is looked up on every call (same late binding as the former
    # lambda); a named def replaces the assigned lambda per PEP 8.
    def drift(x, t, model, **kwargs):
        return self.drift(x, t, model, **kwargs)

    # Integration bounds come from the transport; ODE (not SDE) evaluation
    # interval, with no final step-size adjustment.
    t0, t1 = self.transport.check_interval(
        self.transport.train_eps,
        self.transport.sample_eps,
        sde=False,
        eval=True,
        reverse=reverse,
        last_step_size=0.0,
    )

    _ode = ode(
        drift=drift,
        t0=t0,
        t1=t1,
        sampler_type=sampling_method,
        num_steps=num_steps,
        atol=atol,
        rtol=rtol,
        do_shift=do_shift,
        time_shifting_factor=time_shifting_factor,
    )

    return _ode.sample
| 446 | |
| 447 | def sample_ode_likelihood( |
| 448 | self, |
no test coverage detected