MCPcopy Index your code
hub / github.com/XPixelGroup/DiffBIR / p_sample

Method p_sample

diffbir/sampler/spaced_sampler.py:162–184  ·  view source on GitHub ↗
(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        model_t: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float,
    )

Source from the content-addressed store, hash-verified

160
@torch.no_grad()
def p_sample(
    self,
    model: ControlLDM,
    x: torch.Tensor,
    model_t: torch.Tensor,
    t: torch.Tensor,
    cond: Dict[str, torch.Tensor],
    uncond: Optional[Dict[str, torch.Tensor]],
    cfg_scale: float,
) -> torch.Tensor:
    """Take one reverse-diffusion step and return the sampled previous state.

    The network output from ``self.apply_model`` is turned into an estimate of
    x_0. When ``self.parameterization == "eps"`` the epsilon-prediction helper
    is used; any other value goes through the v-prediction helper. That
    estimate, together with ``x`` and ``t``, is passed to
    ``self.q_posterior_mean_variance``. The result is the posterior mean plus
    ``sqrt(variance)``-scaled Gaussian noise. Entries whose ``t`` equals 0 get
    no noise, so for them the posterior mean is returned as-is. Runs under
    ``torch.no_grad()``.

    Args:
        model: network handle forwarded unchanged to ``self.apply_model``.
        x: current noisy state; also the template for the sampled noise.
        model_t: timestep tensor handed to ``self.apply_model`` (presumably
            the model-facing timestep, as opposed to ``t`` — confirm).
        t: timestep tensor used for the x_0 prediction, the posterior lookup
            and the zero-noise mask. Assumes one entry per leading-dim item
            of ``x`` (TODO confirm against callers).
        cond: conditioning dict forwarded to ``self.apply_model``.
        uncond: optional unconditional dict forwarded to ``self.apply_model``
            (presumably for classifier-free guidance — confirm).
        cfg_scale: scale forwarded to ``self.apply_model``.

    Returns:
        ``mean + mask * sqrt(variance) * noise``, where ``mask`` is 0 for
        entries with ``t == 0`` and 1 otherwise.
    """
    net_out = self.apply_model(model, x, model_t, cond, uncond, cfg_scale)

    # Pick the x_0 estimator matching how the network was trained.
    to_x0 = (
        self._predict_xstart_from_eps
        if self.parameterization == "eps"
        else self._predict_xstart_from_v
    )
    x0_hat = to_x0(x, t, net_out)

    post_mean, post_var = self.q_posterior_mean_variance(x0_hat, x, t)

    eps = torch.randn_like(x)
    # Reshape the per-entry mask to (-1, 1, ..., 1) so it broadcasts over
    # every trailing dimension of x; a t of 0 suppresses the noise term.
    trailing = (1,) * (x.ndim - 1)
    keep_noise = (t != 0).float().view(-1, *trailing)
    return post_mean + keep_noise * torch.sqrt(post_var) * eps
185
186 @torch.no_grad()
187 def sample(

Callers 1

sample — Method · 0.95

Calls 4

apply_model — Method · 0.95

Tested by

no test coverage detected