| 132 | return out |
| 133 | |
class ControlNet(ControlBase):
    """ControlNet that conditions sampling through a separate control model.

    For each step the control model is run on the noisy latent, a resized hint
    image and the cross-attention context. Its output is handed to
    ``control_merge`` together with the output of any chained
    ``previous_controlnet``.
    """

    def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
        """Create the ControlNet wrapper.

        Args:
            control_model: network called in ``get_control``; its ``dtype``
                attribute is the default dtype for the model inputs.
            global_average_pooling: stored flag, forwarded by ``copy``.
            device: forwarded to ``ControlBase.__init__``; the hint tensor is
                moved to ``self.device``.
            load_device: load device given to the ``ModelPatcher`` wrapper.
            manual_cast_dtype: when not None, used instead of
                ``control_model.dtype`` for the model inputs.
        """
        super().__init__(device)
        self.control_model = control_model
        self.load_device = load_device
        # Wrapped in a ModelPatcher that loads on `load_device` and uses the
        # UNet offload device as its offload target; exposed via get_models().
        self.control_model_wrapped = ldm_patched.modules.model_patcher.ModelPatcher(
            self.control_model,
            load_device=load_device,
            offload_device=ldm_patched.modules.model_management.unet_offload_device(),
        )
        self.global_average_pooling = global_average_pooling
        # Set by pre_run(); get_control() dereferences it, so pre_run() must
        # have been called first.
        self.model_sampling_current = None
        self.manual_cast_dtype = manual_cast_dtype

    def get_control(self, x_noisy, t, cond, batched_number):
        """Run the control model for one sampling step.

        Args:
            x_noisy: noisy latent batch. Dims 2 and 3 are treated as the
                spatial size; presumably (batch, channels, height, width) --
                TODO confirm.
            t: step value(s). Only ``t[0]`` is compared with
                ``timestep_range``; the full ``t`` is converted through
                ``model_sampling_current``.
            cond: conditioning dict. ``'c_crossattn'`` is required and ``'y'``
                is optional.
            batched_number: forwarded to ``broadcast_image_to`` when the hint
                batch size differs from the latent batch size.

        Returns:
            The result of ``control_merge``. When ``t[0]`` lies outside
            ``timestep_range``, returns the previous ControlNet's output
            unchanged instead (possibly None).
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            # Outside the active window this ControlNet adds nothing; pass the
            # chained result (which may be None) straight through.
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = self.control_model.dtype if self.manual_cast_dtype is None else self.manual_cast_dtype
        output_dtype = x_noisy.dtype

        # The hint is kept at 8x the latent's spatial size (presumably the VAE
        # downscale factor -- TODO confirm). It is cached and rebuilt only when
        # missing or when the latent's spatial size changes.
        hint_height = x_noisy.shape[2] * 8
        hint_width = x_noisy.shape[3] * 8
        if self.cond_hint is None or self.cond_hint.shape[2] != hint_height or self.cond_hint.shape[3] != hint_width:
            # Drop the stale hint before allocating the resized one.
            self.cond_hint = None
            self.cond_hint = ldm_patched.modules.utils.common_upscale(
                self.cond_hint_original, hint_width, hint_height, 'nearest-exact', "center"
            ).to(dtype).to(self.device)
        # Checked on every call (not only after a resize) so a batch-size change
        # with an unchanged spatial size still gets a matching hint batch.
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        context = cond['c_crossattn']
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        # Convert t into the value the control model receives as `timesteps`,
        # and let the sampling object prepare the latent input
        # (calculate_input presumably applies the sampler's input scaling).
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(
            x=x_noisy.to(dtype),
            hint=self.cond_hint,
            timesteps=timestep.float(),
            context=context.to(dtype),
            y=y,
        )
        return self.control_merge(None, control, control_prev, output_dtype)

    def copy(self):
        """Return a new ControlNet sharing ``control_model``.

        Every constructor argument, including ``device``, is forwarded so the
        copy targets the same device as the original. The remaining state is
        transferred with ``copy_to``.
        """
        c = ControlNet(
            self.control_model,
            global_average_pooling=self.global_average_pooling,
            device=self.device,
            load_device=self.load_device,
            manual_cast_dtype=self.manual_cast_dtype,
        )
        self.copy_to(c)
        return c

    def get_models(self):
        """Return the base class's model list plus this ControlNet's wrapped model."""
        out = super().get_models()
        out.append(self.control_model_wrapped)
        return out

    def pre_run(self, model, percent_to_timestep_function):
        """Run base pre-run setup and remember ``model.model_sampling`` for get_control()."""
        super().pre_run(model, percent_to_timestep_function)
        self.model_sampling_current = model.model_sampling
no outgoing calls
no test coverage detected