(
self,
x: torch.Tensor,
timestep: torch.Tensor,
model_options: dict = {},
seed=None,
)
| 375 | self.inner_set_conds({"positive": positive, "negative": negative}) |
| 376 | |
| 377 | def predict_noise( |
| 378 | self, |
| 379 | x: torch.Tensor, |
| 380 | timestep: torch.Tensor, |
| 381 | model_options: dict = {}, |
| 382 | seed=None, |
| 383 | ): |
| 384 | # CFG zero init - skipping steps with timestep bigger than given threshold. |
| 385 | if timestep > self.skip_steps_sigma_threshold: |
| 386 | return torch.zeros_like(x) |
| 387 | |
| 388 | # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg |
| 389 | # but we'd rather do a single batch of sampling pos, neg, and perturbed, so we call calc_cond_batch([perturbed,pos,neg]) directly |
| 390 | positive_cond = self.conds.get("positive", None) |
| 391 | negative_cond = self.conds.get("negative", None) |
| 392 | |
| 393 | cfg_value, stg_scale, stg_rescale, stg_layer_skip_layer_indices = ( |
| 394 | self.sigma_to_params_mapping(timestep) |
| 395 | ) |
| 396 | |
| 397 | if stg_layer_skip_layer_indices is not None: |
| 398 | self.stg_flag.skip_layers = stg_layer_skip_layer_indices |
| 399 | STGGuider.patch_model(self.model_patcher, self.stg_flag) |
| 400 | |
| 401 | noise_pred_pos = comfy.samplers.calc_cond_batch( |
| 402 | self.inner_model, |
| 403 | [positive_cond], |
| 404 | x, |
| 405 | timestep, |
| 406 | model_options, |
| 407 | )[0] |
| 408 | |
| 409 | noise_pred_neg = 0 |
| 410 | noise_pred_perturbed = 0 |
| 411 | |
| 412 | if not math.isclose(cfg_value, 1.0) or ( |
| 413 | self.apply_apg and not math.isclose(self.apg_cfg_scale, 1.0) |
| 414 | ): |
| 415 | noise_pred_neg = comfy.samplers.calc_cond_batch( |
| 416 | self.inner_model, |
| 417 | [negative_cond], |
| 418 | x, |
| 419 | timestep, |
| 420 | model_options, |
| 421 | )[0] |
| 422 | |
| 423 | if self.cfg_star_rescale: |
| 424 | batch_size = noise_pred_pos.shape[0] |
| 425 | |
| 426 | positive_flat = noise_pred_pos.view(batch_size, -1) |
| 427 | negative_flat = noise_pred_neg.view(batch_size, -1) |
| 428 | dot_product = torch.sum( |
| 429 | positive_flat * negative_flat, dim=1, keepdim=True |
| 430 | ) |
| 431 | squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 |
| 432 | alpha = dot_product / squared_norm |
| 433 | noise_pred_neg = alpha * noise_pred_neg |
| 434 |
Nothing calls this directly.
No test coverage detected.