(
self,
x: torch.Tensor,
timestep: torch.Tensor,
model_options: dict = {},
seed=None,
)
| 225 | self.inner_set_conds({"positive": positive, "negative": negative}) |
| 226 | |
| 227 | def predict_noise( |
| 228 | self, |
| 229 | x: torch.Tensor, |
| 230 | timestep: torch.Tensor, |
| 231 | model_options: dict = {}, |
| 232 | seed=None, |
| 233 | ): |
| 234 | # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg |
| 235 | # but we'd rather do a single batch of sampling pos, neg, and perturbed, so we call calc_cond_batch([perturbed,pos,neg]) directly |
| 236 | |
| 237 | positive_cond = self.conds.get("positive", None) |
| 238 | negative_cond = self.conds.get("negative", None) |
| 239 | |
| 240 | if model_options.get("sigma_to_params_mapping", None) is not None: |
| 241 | cfg_value, stg_scale, stg_layer_skip_layer_indices, stg_rescale = ( |
| 242 | model_options["sigma_to_params_mapping"](timestep) |
| 243 | ) |
| 244 | self.stg_flag.skip_layers = stg_layer_skip_layer_indices |
| 245 | self.patch_model(self.model_patcher, self.stg_flag) |
| 246 | |
| 247 | else: |
| 248 | cfg_value = self.cfg |
| 249 | stg_scale = self.stg_scale |
| 250 | stg_rescale = self.rescale_scale |
| 251 | |
| 252 | noise_pred_pos = comfy.samplers.calc_cond_batch( |
| 253 | self.inner_model, |
| 254 | [positive_cond], |
| 255 | x, |
| 256 | timestep, |
| 257 | model_options, |
| 258 | )[0] |
| 259 | |
| 260 | noise_pred_neg = 0 |
| 261 | noise_pred_perturbed = 0 |
| 262 | |
| 263 | if not math.isclose(cfg_value, 1.0): |
| 264 | noise_pred_neg = comfy.samplers.calc_cond_batch( |
| 265 | self.inner_model, |
| 266 | [negative_cond], |
| 267 | x, |
| 268 | timestep, |
| 269 | model_options, |
| 270 | )[0] |
| 271 | |
| 272 | if not math.isclose(stg_scale, 0.0): |
| 273 | try: |
| 274 | model_options["transformer_options"]["ptb_index"] = 0 |
| 275 | self.stg_flag.do_skip = True |
| 276 | noise_pred_perturbed = comfy.samplers.calc_cond_batch( |
| 277 | self.inner_model, |
| 278 | [positive_cond], |
| 279 | x, |
| 280 | timestep, |
| 281 | model_options, |
| 282 | )[0] |
| 283 | finally: |
| 284 | self.stg_flag.do_skip = False |
nothing calls this directly
no test coverage detected