MCPcopy Index your code
hub / github.com/XPixelGroup/DiffBIR / SpacedSampler

Class SpacedSampler

diffbir/sampler/spaced_sampler.py:67–245  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

65
66
67class SpacedSampler(Sampler):
68
def __init__(
    self,
    betas: np.ndarray,
    parameterization: Literal["eps", "v"],
    rescale_cfg: bool,
) -> None:
    """Initialise the sampler by delegating entirely to ``Sampler.__init__``.

    No extra state is set here; the spaced schedule is built separately by
    ``make_schedule``.

    Args:
        betas: Beta schedule handed to the base class (presumably the
            training schedule from which ``training_alphas_cumprod`` is
            derived -- confirm against ``Sampler``).
        parameterization: Either ``"eps"`` or ``"v"``; forwarded unchanged.
        rescale_cfg: Flag forwarded unchanged to the base class.
    """
    # Constructors always return None (PEP 484); annotating ``__init__`` with
    # the class type is rejected by type checkers such as mypy.
    super().__init__(betas, parameterization, rescale_cfg)
76
def make_schedule(self, num_steps: int) -> None:
    """Re-space the training schedule onto ``num_steps`` sampling steps.

    A subset of training timesteps is chosen with ``space_timesteps``. A new
    beta sequence is rebuilt over exactly those timesteps, and the derived
    per-step coefficients are stored through ``self.register``.

    Side effects:
        * ``self.timesteps`` becomes the ascending selected training indices
          as an ``np.int32`` array.
        * Eight coefficient arrays are registered, in this order:
          sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
          sqrt_recip_alphas_cumprod, sqrt_recipm1_alphas_cumprod,
          posterior_variance, posterior_log_variance_clipped,
          posterior_mean_coef1, posterior_mean_coef2.
    """
    keep = space_timesteps(self.num_timesteps, str(num_steps))

    # Each kept timestep gets beta = 1 - cumprod / previous kept cumprod.
    # The product of the new alphas up to a kept index therefore telescopes
    # back to the training cumulative product at that index.
    spaced_betas = []
    prev_cumprod = 1.0
    for idx, cumprod in enumerate(self.training_alphas_cumprod):
        if idx not in keep:
            continue
        spaced_betas.append(1 - cumprod / prev_cumprod)
        prev_cumprod = cumprod

    # Training-schedule indices visited while sampling, ascending.
    self.timesteps = np.array(sorted(keep), dtype=np.int32)

    betas = np.array(spaced_betas, dtype=np.float64)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    one_minus_cumprod = 1.0 - alphas_cumprod

    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / one_minus_cumprod
    # alphas_cumprod_prev[0] is 1, which normally makes posterior_variance[0]
    # zero (log -> -inf). The log variant therefore reuses entry 1 in slot 0.
    # NOTE(review): this needs at least two kept timesteps; with only one,
    # posterior_variance[1] raises IndexError (presumably num_steps >= 2).
    posterior_log_variance_clipped = np.log(
        np.append(posterior_variance[1], posterior_variance[1:])
    )

    # Dict insertion order preserves the registration order.
    buffers = {
        "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
        "sqrt_one_minus_alphas_cumprod": np.sqrt(1 - alphas_cumprod),
        "sqrt_recip_alphas_cumprod": np.sqrt(1.0 / alphas_cumprod),
        "sqrt_recipm1_alphas_cumprod": np.sqrt(1.0 / alphas_cumprod - 1),
        "posterior_variance": posterior_variance,
        "posterior_log_variance_clipped": posterior_log_variance_clipped,
        "posterior_mean_coef1": (
            betas * np.sqrt(alphas_cumprod_prev) / one_minus_cumprod
        ),
        "posterior_mean_coef2": (
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / one_minus_cumprod
        ),
    }
    for name, value in buffers.items():
        self.register(name, value)
117
118 def q_posterior_mean_variance(
119 self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor
120 ) -> Tuple[torch.Tensor]:
121 mean = (
122 extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
123 + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
124 )

Callers 2

main (function) · 0.90
apply_cldm (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected