| 357 | STGGuider.patch_model(model, self.stg_flag) |
| 358 | |
def sigma_to_params_mapping(self, sigma):
    """Return the guidance parameters scheduled for the given ``sigma``.

    ``self.sigma_list``, ``self.cfg_list``, ``self.stg_scale_list``,
    ``self.stg_rescale_list`` and ``self.stg_layers_indices_list`` are read
    as parallel lists sharing one index. The selected index is the position
    of the smallest entry of ``sigma_list`` that is still ``>= sigma``; among
    equal values the earliest position wins. When no entry is ``>= sigma``,
    index ``-1`` (the last entry of each list) is used.

    Returns:
        A 4-tuple ``(cfg, stg_scale, stg_rescale, stg_layers_indices)`` taken
        from the selected index of the corresponding lists.
    """
    # Positions whose scheduled sigma is at or above the query sigma.
    eligible = [pos for pos, scheduled in enumerate(self.sigma_list) if scheduled >= sigma]

    # min() with a key keeps the first position holding the minimal value,
    # matching "first occurrence of the closest higher sigma".
    if eligible:
        idx = min(eligible, key=lambda pos: self.sigma_list[pos])
    else:
        idx = -1  # nothing at or above sigma: fall back to the last entry

    return (
        self.cfg_list[idx],
        self.stg_scale_list[idx],
        self.stg_rescale_list[idx],
        self.stg_layers_indices_list[idx],
    )
| 373 | |
| 374 | def set_conds(self, positive, negative): |
| 375 | self.inner_set_conds({"positive": positive, "negative": negative}) |