Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps.
p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None)
| 252 | return posterior_mean, posterior_variance, posterior_log_variance_clipped |
| 253 | |
| 254 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): |
| 255 | """ |
| 256 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of |
| 257 | the initial x, x_0. |
| 258 | :param model: the model, which takes a signal and a batch of timesteps |
| 259 | as input. |
| 260 | :param x: the [N x C x ...] tensor at time t. |
| 261 | :param t: a 1-D Tensor of timesteps. |
| 262 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. |
| 263 | :param denoised_fn: if not None, a function which applies to the |
| 264 | x_start prediction before it is used to sample. Applies before |
| 265 | clip_denoised. |
| 266 | :param model_kwargs: if not None, a dict of extra keyword arguments to |
| 267 | pass to the model. This can be used for conditioning. |
| 268 | :return: a dict with the following keys: |
| 269 | - 'mean': the model mean output. |
| 270 | - 'variance': the model variance output. |
| 271 | - 'log_variance': the log of 'variance'. |
| 272 | - 'pred_xstart': the prediction for x_0. |
| 273 | """ |
| 274 | if model_kwargs is None: |
| 275 | model_kwargs = {} |
| 276 | |
| 277 | B, F, C = x.shape[:3] |
| 278 | assert t.shape == (B,) |
| 279 | model_output = model(x, t, **model_kwargs) |
| 280 | # try: |
| 281 | # model_output = model_output.sample # for tav unet |
| 282 | # except: |
| 283 | # model_output = model(x, t, **model_kwargs) |
| 284 | if isinstance(model_output, tuple): |
| 285 | model_output, extra = model_output |
| 286 | else: |
| 287 | extra = None |
| 288 | |
| 289 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: |
| 290 | assert model_output.shape == (B, F, C * 2, *x.shape[3:]) |
| 291 | model_output, model_var_values = th.split(model_output, C, dim=2) |
| 292 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) |
| 293 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) |
| 294 | # The model_var_values is [-1, 1] for [min_var, max_var]. |
| 295 | frac = (model_var_values + 1) / 2 |
| 296 | model_log_variance = frac * max_log + (1 - frac) * min_log |
| 297 | model_variance = th.exp(model_log_variance) |
| 298 | else: |
| 299 | model_variance, model_log_variance = { |
| 300 | # for fixedlarge, we set the initial (log-)variance like so |
| 301 | # to get a better decoder log likelihood. |
| 302 | ModelVarType.FIXED_LARGE: ( |
| 303 | np.append(self.posterior_variance[1], self.betas[1:]), |
| 304 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), |
| 305 | ), |
| 306 | ModelVarType.FIXED_SMALL: ( |
| 307 | self.posterior_variance, |
| 308 | self.posterior_log_variance_clipped, |
| 309 | ), |
| 310 | }[self.model_var_type] |
| 311 | model_variance = _extract_into_tensor(model_variance, t, x.shape) |
no test coverage detected