Repository: github.com/Vchitect/Latte

Method p_mean_variance

diffusion/gaussian_diffusion.py:254–336

Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0.

(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None)
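The docstring says `denoised_fn` is applied to the x_start prediction before `clip_denoised` clamps it to [-1, 1]; that post-processing is presumably in the part of the method past the end of the listing below. A minimal pure-Python sketch of just that ordering, assuming lists of floats in place of tensors; `process_xstart` is a hypothetical name here, not a confirmed part of Latte's API:

```python
from typing import Callable, List, Optional


def process_xstart(
    x: List[float],
    denoised_fn: Optional[Callable[[List[float]], List[float]]] = None,
    clip_denoised: bool = True,
) -> List[float]:
    """Post-process a predicted x_0 in the order the docstring specifies (sketch)."""
    # 1. denoised_fn, if given, sees the raw prediction first.
    if denoised_fn is not None:
        x = denoised_fn(x)
    # 2. Clipping to [-1, 1] is applied afterwards, to denoised_fn's output.
    if clip_denoised:
        x = [min(1.0, max(-1.0, v)) for v in x]
    return x
```

Because clipping runs last, a `denoised_fn` that rescales its input can push values outside [-1, 1] and they are still clipped: `process_xstart([0.3, -0.8], lambda xs: [2 * v for v in xs])` gives `[0.6, -1.0]`.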


```python
        # (last line of the preceding method)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        # x is laid out as [B, F, C, ...]: channels sit on dim 2, after the
        # F axis (presumably frames of the video latent).
        B, F, C = x.shape[:3]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        # try:
        #     model_output = model_output.sample # for tav unet
        # except:
        #     model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        # Both learned types use the interpolated ("range") parameterization here.
        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            # The model emits 2 * C channels: the first C are its main prediction,
            # the last C are the variance values.
            assert model_output.shape == (B, F, C * 2, *x.shape[3:])
            model_output, model_var_values = th.split(model_output, C, dim=2)
            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            # ... (listing ends here; the method continues to line 336 of the source file)
```
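The two branches above pick the per-timestep variance in three ways. With `LEARNED`/`LEARNED_RANGE`, the extra output channels v in [-1, 1] interpolate in log space between log β̃_t (`min_log`, the clipped posterior log-variance) and log β_t (`max_log`), so v = 0 gives the geometric mean of the two variances. `FIXED_SMALL` uses β̃_t directly; `FIXED_LARGE` uses β_t, except at t = 0 where β̃_1 is substituted. A pure-Python sketch for a single timestep index, assuming a linear β schedule and the standard DDPM posterior variance β̃_t = β_t (1 − ᾱ_{t−1}) / (1 − ᾱ_t) (Latte's constructor is not shown here; the schedule, helper names, and list-based tables are illustrative assumptions):

```python
import math
from typing import Dict, List, Tuple


def linear_betas(num_steps: int, start: float = 1e-4, end: float = 0.02) -> List[float]:
    """Assumed example schedule: num_steps (>= 2) betas spaced linearly from start to end."""
    step = (end - start) / (num_steps - 1)
    return [start + step * i for i in range(num_steps)]


def variance_tables(betas: List[float]) -> Dict[str, Tuple[List[float], List[float]]]:
    """Per-timestep (variance, log_variance) lists for FIXED_SMALL and FIXED_LARGE."""
    # Posterior variance of q(x_{t-1} | x_t, x_0):
    #   beta_tilde_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t), with abar_{-1} = 1,
    # so beta_tilde_0 is exactly 0.
    posterior: List[float] = []
    abar_prev = 1.0
    for beta in betas:
        abar = abar_prev * (1.0 - beta)
        posterior.append(beta * (1.0 - abar_prev) / (1.0 - abar))
        abar_prev = abar
    # log(0) is undefined, so entry 0 of the log table reuses beta_tilde_1 ("clipped").
    posterior_log_clipped = [math.log(posterior[1])] + [math.log(v) for v in posterior[1:]]
    # FIXED_LARGE: beta_t itself, except entry 0 is replaced by beta_tilde_1.
    large = [posterior[1]] + list(betas[1:])
    return {
        "fixed_small": (posterior, posterior_log_clipped),
        "fixed_large": (large, [math.log(v) for v in large]),
    }


def learned_range_log_variance(model_var_value: float, t: int, betas: List[float]) -> float:
    """Map a model output in [-1, 1] linearly onto [log beta_tilde_t, log beta_t]."""
    # Tables are recomputed on every call, which is fine for a sketch.
    min_log = variance_tables(betas)["fixed_small"][1][t]
    max_log = math.log(betas[t])
    frac = (model_var_value + 1) / 2
    return frac * max_log + (1 - frac) * min_log
```

For example, with `betas = linear_betas(1000)`, `math.exp(learned_range_log_variance(0.0, 500, betas))` equals `math.sqrt(beta_tilde_500 * beta_500)` up to rounding. Interpolating in log space keeps the variance positive for any v and spans the small gap between β̃_t and β_t multiplicatively rather than additively.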

Callers 4

p_sample · method · 0.95
ddim_sample · method · 0.95
ddim_reverse_sample · method · 0.95
_vb_terms_bpd · method · 0.95
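`p_sample` is listed first among the callers. Its body is not shown on this page; in the guided-diffusion-style code this file appears to follow, it draws x_{t-1} from the Gaussian that `p_mean_variance` describes, using `mean` and `log_variance`, and adds no noise when t = 0. A pure-Python sketch of that single step, assuming one sample with a scalar timestep (the real code works on batched tensors with a per-sample mask); `p_sample_step` is a hypothetical name:

```python
import math
import random
from typing import Dict, List, Optional


def p_sample_step(
    out: Dict[str, List[float]], t: int, rng: Optional[random.Random] = None
) -> List[float]:
    """Draw x_{t-1} = mean + [t != 0] * exp(0.5 * log_variance) * noise, elementwise."""
    if rng is None:
        rng = random.Random()
    # No noise on the last step: at t == 0 the mean is returned as the sample.
    nonzero = 0.0 if t == 0 else 1.0
    # exp(0.5 * log_variance) is the standard deviation.
    std = [math.exp(0.5 * lv) for lv in out["log_variance"]]
    return [m + nonzero * s * rng.gauss(0.0, 1.0) for m, s in zip(out["mean"], std)]
```

Using `log_variance` rather than `variance` here means the standard deviation comes from a single `exp`, which matches why `p_mean_variance` returns both.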

Calls 4

_extract_into_tensor · function · 0.85
np.append · function · 0.80
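`_extract_into_tensor(table, t, x.shape)` is what lets 1-D per-timestep tables such as `self.betas` or `self.posterior_log_variance_clipped` combine elementwise with `x`: each batch element's timestep selects one scalar, which is then repeated across every remaining axis of `x.shape`. Its implementation is not shown here, so the following stand-in, `extract_into_nested` (hypothetical), only mirrors those semantics with nested Python lists rather than whatever tensor operations the real helper uses:

```python
from typing import Any, Sequence


def extract_into_nested(
    arr: Sequence[float], timesteps: Sequence[int], broadcast_shape: Sequence[int]
) -> Any:
    """Nested lists of broadcast_shape where batch row b is filled with arr[timesteps[b]]."""
    if len(timesteps) != broadcast_shape[0]:
        raise ValueError("batch dimension must equal len(timesteps)")

    def fill(value: float, dims: Sequence[int]) -> Any:
        # Base case: no dimensions left, return the scalar itself.
        if not dims:
            return value
        return [fill(value, dims[1:]) for _ in range(dims[0])]

    return [fill(float(arr[t]), broadcast_shape[1:]) for t in timesteps]
```

For instance, `extract_into_nested([0.1, 0.2, 0.3], [2, 0], (2, 1, 2))` returns `[[[0.3, 0.3]], [[0.1, 0.1]]]`: the first sample (t = 2) gets 0.3 everywhere, the second (t = 0) gets 0.1.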

Tested by

no test coverage detected