MCPcopy
hub / github.com/XPixelGroup/DiffBIR / make_schedule

Method make_schedule

diffbir/sampler/spaced_sampler.py:77–116  ·  view source on GitHub ↗
(self, num_steps: int)

Source from the content-addressed store, hash-verified

75 super().__init__(betas, parameterization, rescale_cfg)
76
def make_schedule(self, num_steps: int) -> None:
    """Respace the diffusion schedule to ``num_steps`` sampling steps.

    A subset of the training timesteps is chosen with ``space_timesteps``.
    New per-step betas are derived so that the running product of
    ``1 - beta`` over the kept steps equals (up to float rounding) the
    training ``alphas_cumprod`` values at those timesteps. The kept
    timesteps are stored in ascending order in ``self.timesteps`` (int32),
    and the derived float64 coefficient arrays are handed to
    ``self.register`` by name.

    Args:
        num_steps: desired number of sampling steps; forwarded to
            ``space_timesteps`` as a string.

    Note:
        If ``space_timesteps`` keeps fewer than two timesteps,
        ``posterior_variance[1]`` does not exist and an ``IndexError``
        is raised.
    """
    kept = space_timesteps(self.num_timesteps, str(num_steps))

    # Training alpha-bar values at the kept timesteps, in training order.
    kept_abar = [
        abar
        for t, abar in enumerate(self.training_alphas_cumprod)
        if t in kept
    ]
    # beta_k = 1 - abar_k / abar_{k-1}, taking abar_{-1} = 1.0, so the
    # products of (1 - beta) telescope back to kept_abar.
    prev_abar = [1.0] + kept_abar[:-1]
    new_betas = [1 - cur / prev for prev, cur in zip(prev_abar, kept_abar)]
    # Ascending training indices of the kept steps, e.g. [0, 10, 20, ...].
    self.timesteps = np.array(sorted(kept), dtype=np.int32)

    betas = np.array(new_betas, dtype=np.float64)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    # alpha-bar of the previous respaced step (1.0 before the first one).
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    # Shared denominator of the posterior terms below.
    one_minus_abar = 1.0 - alphas_cumprod

    sqrt_recip_abar = np.sqrt(1.0 / alphas_cumprod)
    sqrt_recipm1_abar = np.sqrt(1.0 / alphas_cumprod - 1)
    post_var = betas * (1.0 - alphas_cumprod_prev) / one_minus_abar
    # post_var[0] carries a zero factor (alphas_cumprod_prev[0] == 1.0), so
    # its log would be -inf; the first entry reuses post_var[1] instead.
    post_log_var = np.log(np.append(post_var[1], post_var[1:]))
    post_coef1 = betas * np.sqrt(alphas_cumprod_prev) / one_minus_abar
    post_coef2 = (
        (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / one_minus_abar
    )

    # Registered in a fixed order; names are the attribute keys used elsewhere.
    for buffer_name, buffer_value in (
        ("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod)),
        ("sqrt_one_minus_alphas_cumprod", np.sqrt(one_minus_abar)),
        ("sqrt_recip_alphas_cumprod", sqrt_recip_abar),
        ("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_abar),
        ("posterior_variance", post_var),
        ("posterior_log_variance_clipped", post_log_var),
        ("posterior_mean_coef1", post_coef1),
        ("posterior_mean_coef2", post_coef2),
    ):
        self.register(buffer_name, buffer_value)
117
118 def q_posterior_mean_variance(
119 self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor

Callers 1

sample (method) · 0.95

Calls 2

space_timesteps (function) · 0.85
register (method) · 0.45

Tested by

no test coverage detected