Wraps a discretizer and prunes the sigmas. Params: strength: float between 0.0 and 1.0; 1.0 means full sampling (all sigmas are returned).
| 75 | |
| 76 | |
class Img2ImgDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas it produces.

    Only the trailing ``max(int(strength * len(sigmas)), 1)`` entries of the
    wrapped discretizer's output are kept, so at least one sigma is always
    returned.

    params:
        discretization: callable invoked with the arguments given to
            ``__call__``; its result is passed to ``torch.flip`` and sliced
            along dim 0 (assumes a torch tensor ordered from large to small
            sigmas -- TODO confirm against the discretizers used)
        strength: float between 0.0 and 1.0. 1.0 means full sampling
            (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0, "strength must be between 0.0 and 1.0"

    def __call__(self, *args, **kwargs):
        """Run the wrapped discretizer and return the pruned tail of its sigmas."""
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # Compute the cut-off once, from the unpruned length, so the value
        # reported below is the one actually used for slicing.
        prune_index = max(int(self.strength * len(sigmas)), 1)
        # Flip so the tail of the schedule comes first, keep the first
        # prune_index entries, then flip back to restore the original order.
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas
| 99 | |
| 100 | |
| 101 | def do_sample( |