(sampler_name, steps, discretization_config, guider_config, key=1)
| 399 | |
| 400 | |
| 401 | def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1): |
| 402 | if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler": |
| 403 | s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0) |
| 404 | s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0) |
| 405 | s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0) |
| 406 | s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0) |
| 407 | |
| 408 | if sampler_name == "EulerEDMSampler": |
| 409 | sampler = EulerEDMSampler( |
| 410 | num_steps=steps, |
| 411 | discretization_config=discretization_config, |
| 412 | guider_config=guider_config, |
| 413 | s_churn=s_churn, |
| 414 | s_tmin=s_tmin, |
| 415 | s_tmax=s_tmax, |
| 416 | s_noise=s_noise, |
| 417 | verbose=True, |
| 418 | ) |
| 419 | elif sampler_name == "HeunEDMSampler": |
| 420 | sampler = HeunEDMSampler( |
| 421 | num_steps=steps, |
| 422 | discretization_config=discretization_config, |
| 423 | guider_config=guider_config, |
| 424 | s_churn=s_churn, |
| 425 | s_tmin=s_tmin, |
| 426 | s_tmax=s_tmax, |
| 427 | s_noise=s_noise, |
| 428 | verbose=True, |
| 429 | ) |
| 430 | elif ( |
| 431 | sampler_name == "EulerAncestralSampler" |
| 432 | or sampler_name == "DPMPP2SAncestralSampler" |
| 433 | ): |
| 434 | s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0) |
| 435 | eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0) |
| 436 | |
| 437 | if sampler_name == "EulerAncestralSampler": |
| 438 | sampler = EulerAncestralSampler( |
| 439 | num_steps=steps, |
| 440 | discretization_config=discretization_config, |
| 441 | guider_config=guider_config, |
| 442 | eta=eta, |
| 443 | s_noise=s_noise, |
| 444 | verbose=True, |
| 445 | ) |
| 446 | elif sampler_name == "DPMPP2SAncestralSampler": |
| 447 | sampler = DPMPP2SAncestralSampler( |
| 448 | num_steps=steps, |
| 449 | discretization_config=discretization_config, |
| 450 | guider_config=guider_config, |
| 451 | eta=eta, |
| 452 | s_noise=s_noise, |
| 453 | verbose=True, |
| 454 | ) |
| 455 | elif sampler_name == "DPMPP2MSampler": |
| 456 | sampler = DPMPP2MSampler( |
| 457 | num_steps=steps, |
| 458 | discretization_config=discretization_config, |
no test coverage detected