MCPcopy
hub / github.com/Stability-AI/generative-models / get_sampler

Function get_sampler

scripts/demo/streamlit_helpers.py:401–474  ·  view source on GitHub ↗
(sampler_name, steps, discretization_config, guider_config, key=1)

Source from the content-addressed store, hash-verified

399
400
401def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
402 if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
403 s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
404 s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
405 s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
406 s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
407
408 if sampler_name == "EulerEDMSampler":
409 sampler = EulerEDMSampler(
410 num_steps=steps,
411 discretization_config=discretization_config,
412 guider_config=guider_config,
413 s_churn=s_churn,
414 s_tmin=s_tmin,
415 s_tmax=s_tmax,
416 s_noise=s_noise,
417 verbose=True,
418 )
419 elif sampler_name == "HeunEDMSampler":
420 sampler = HeunEDMSampler(
421 num_steps=steps,
422 discretization_config=discretization_config,
423 guider_config=guider_config,
424 s_churn=s_churn,
425 s_tmin=s_tmin,
426 s_tmax=s_tmax,
427 s_noise=s_noise,
428 verbose=True,
429 )
430 elif (
431 sampler_name == "EulerAncestralSampler"
432 or sampler_name == "DPMPP2SAncestralSampler"
433 ):
434 s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
435 eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
436
437 if sampler_name == "EulerAncestralSampler":
438 sampler = EulerAncestralSampler(
439 num_steps=steps,
440 discretization_config=discretization_config,
441 guider_config=guider_config,
442 eta=eta,
443 s_noise=s_noise,
444 verbose=True,
445 )
446 elif sampler_name == "DPMPP2SAncestralSampler":
447 sampler = DPMPP2SAncestralSampler(
448 num_steps=steps,
449 discretization_config=discretization_config,
450 guider_config=guider_config,
451 eta=eta,
452 s_noise=s_noise,
453 verbose=True,
454 )
455 elif sampler_name == "DPMPP2MSampler":
456 sampler = DPMPP2MSampler(
457 num_steps=steps,
458 discretization_config=discretization_config,

Callers 1

init_sampling · Function · 0.85

Calls 6

EulerEDMSampler · Class · 0.90
HeunEDMSampler · Class · 0.90
DPMPP2MSampler · Class · 0.90

Tested by

no test coverage detected