MCPcopy
hub / github.com/Lightricks/ComfyUI-LTXVideo / predict_noise

Method predict_noise

stg.py:377–490  ·  view source on GitHub ↗
(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        model_options: dict = {},
        seed=None,
    )

Source from the content-addressed store, hash-verified

375 self.inner_set_conds({"positive": positive, "negative": negative})
376
377 def predict_noise(
378 self,
379 x: torch.Tensor,
380 timestep: torch.Tensor,
381 model_options: dict = {},
382 seed=None,
383 ):
384 # CFG zero init - skipping steps with timestep bigger than given threshold.
385 if timestep > self.skip_steps_sigma_threshold:
386 return torch.zeros_like(x)
387
388 # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
389 # but we'd rather do a single batch of sampling pos, neg, and perturbed, so we call calc_cond_batch([perturbed,pos,neg]) directly
390 positive_cond = self.conds.get("positive", None)
391 negative_cond = self.conds.get("negative", None)
392
393 cfg_value, stg_scale, stg_rescale, stg_layer_skip_layer_indices = (
394 self.sigma_to_params_mapping(timestep)
395 )
396
397 if stg_layer_skip_layer_indices is not None:
398 self.stg_flag.skip_layers = stg_layer_skip_layer_indices
399 STGGuider.patch_model(self.model_patcher, self.stg_flag)
400
401 noise_pred_pos = comfy.samplers.calc_cond_batch(
402 self.inner_model,
403 [positive_cond],
404 x,
405 timestep,
406 model_options,
407 )[0]
408
409 noise_pred_neg = 0
410 noise_pred_perturbed = 0
411
412 if not math.isclose(cfg_value, 1.0) or (
413 self.apply_apg and not math.isclose(self.apg_cfg_scale, 1.0)
414 ):
415 noise_pred_neg = comfy.samplers.calc_cond_batch(
416 self.inner_model,
417 [negative_cond],
418 x,
419 timestep,
420 model_options,
421 )[0]
422
423 if self.cfg_star_rescale:
424 batch_size = noise_pred_pos.shape[0]
425
426 positive_flat = noise_pred_pos.view(batch_size, -1)
427 negative_flat = noise_pred_neg.view(batch_size, -1)
428 dot_product = torch.sum(
429 positive_flat * negative_flat, dim=1, keepdim=True
430 )
431 squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
432 alpha = dot_product / squared_norm
433 noise_pred_neg = alpha * noise_pred_neg
434

Callers

nothing calls this directly

Calls 5

stg (Function) · 0.85
apg (Function) · 0.85
get (Method) · 0.80
patch_model (Method) · 0.45

Tested by

no test coverage detected