(
self,
x: torch.Tensor,
timestep: torch.Tensor,
model_options: dict = {},
seed=None,
)
| 816 | self.inner_set_conds({"positive": positive, "negative": negative}) |
| 817 | |
| 818 | def predict_noise( |
| 819 | self, |
| 820 | x: torch.Tensor, |
| 821 | timestep: torch.Tensor, |
| 822 | model_options: dict = {}, |
| 823 | seed=None, |
| 824 | ): |
| 825 | # CFG zero init - skipping steps with timestep bigger than given threshold. |
| 826 | |
| 827 | # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg |
| 828 | # but we'd rather do a single batch of sampling pos, neg, and perturbed, so we call calc_cond_batch([perturbed,pos,neg]) directly |
| 829 | positive_cond = self.conds.get("positive", None) |
| 830 | negative_cond = self.conds.get("negative", None) |
| 831 | |
| 832 | noise_pred_pos = comfy.samplers.calc_cond_batch( |
| 833 | self.inner_model, |
| 834 | [positive_cond], |
| 835 | x, |
| 836 | timestep, |
| 837 | model_options, |
| 838 | )[0] |
| 839 | |
| 840 | if ( |
| 841 | self.previous_timestep is not None |
| 842 | and timestep.item() > self.previous_timestep |
| 843 | ): |
| 844 | print("Resetting momentum buffer") |
| 845 | self.momentum_buffer = MomentumBuffer(self.momentum_coefficient) |
| 846 | |
| 847 | noise_pred_neg = 0 |
| 848 | if not math.isclose(self.cfg_scale, 1.0): |
| 849 | noise_pred_neg = comfy.samplers.calc_cond_batch( |
| 850 | self.inner_model, |
| 851 | [negative_cond], |
| 852 | x, |
| 853 | timestep, |
| 854 | model_options, |
| 855 | )[0] |
| 856 | |
| 857 | apg_result = apg( |
| 858 | noise_pred_pos, |
| 859 | noise_pred_neg, |
| 860 | self.cfg_scale, |
| 861 | self.eta, |
| 862 | self.norm_threshold, |
| 863 | ) |
| 864 | |
| 865 | # normally this would be done in cfg_function, but we skipped |
| 866 | # that for efficiency: we can compute the noise predictions in |
| 867 | # a single call to calc_cond_batch() (rather than two) |
| 868 | # so we replicate the hook here |
| 869 | for fn in model_options.get("sampler_post_cfg_function", []): |
| 870 | args = { |
| 871 | "denoised": apg_result, |
| 872 | "cond": positive_cond, |
| 873 | "uncond": negative_cond, |
| 874 | "model": self.inner_model, |
| 875 | "uncond_denoised": noise_pred_neg, |
nothing calls this directly
no test coverage detected