MCPcopy
hub / github.com/lllyasviel/Fooocus / ControlNet

Class ControlNet

ldm_patched/modules/controlnet.py:134–195  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

132 return out
133
class ControlNet(ControlBase):
    """ControlNet control-signal provider built on ``ControlBase``.

    Wraps ``control_model`` in a ``ModelPatcher`` (with explicit load/offload
    devices) that is exposed through ``get_models``. On each call it resizes the
    hint image to the current latent size, runs the control model, and merges
    the result with the output of any chained ``previous_controlnet``.
    """

    def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
        """Create the wrapper.

        Args:
            control_model: the control network. ``get_control`` calls it with
                ``x``, ``hint``, ``timesteps``, ``context`` and ``y``, and reads
                its ``dtype``.
            global_average_pooling: stored on the instance and propagated by
                ``copy``. It is not read directly in this class; presumably
                ``ControlBase`` consumes it (confirm).
            device: forwarded to ``ControlBase.__init__``. The resized hint is
                moved to ``self.device``.
            load_device: load device given to the ``ModelPatcher``. The offload
                device is ``unet_offload_device()``.
            manual_cast_dtype: when not None, overrides ``control_model.dtype``
                as the compute dtype in ``get_control``.
        """
        super().__init__(device)
        self.control_model = control_model
        self.load_device = load_device
        self.control_model_wrapped = ldm_patched.modules.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=ldm_patched.modules.model_management.unet_offload_device())
        self.global_average_pooling = global_average_pooling
        # Set by pre_run(); get_control() requires it to be populated first.
        self.model_sampling_current = None
        self.manual_cast_dtype = manual_cast_dtype

    def get_control(self, x_noisy, t, cond, batched_number):
        """Compute this step's control signal, merged with any chained controlnet.

        Args:
            x_noisy: noisy latent batch. Dims 2 and 3 are treated as spatial
                (presumably NCHW — confirm); the hint is kept at 8x that size.
            t: timestep tensor. Only ``t[0]`` is compared against
                ``timestep_range``.
            cond: conditioning dict. ``'c_crossattn'`` is required; ``'y'`` is
                optional.
            batched_number: forwarded to ``broadcast_image_to`` when the hint
                batch size differs from ``x_noisy``'s.

        Returns:
            The result of ``control_merge``. When ``t[0]`` lies outside
            ``timestep_range``, returns the chained controlnet's output instead
            (None if there is no chained controlnet).
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        # Active only while timestep_range[1] <= t[0] <= timestep_range[0].
        # Outside that window this controlnet contributes nothing and simply
        # passes through whatever the chain produced (possibly None).
        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype

        # The hint is cached at 8x the latent's spatial size (hard-coded factor)
        # and only rebuilt when that size changes.
        hint_height = x_noisy.shape[2] * 8
        hint_width = x_noisy.shape[3] * 8
        if self.cond_hint is None or self.cond_hint.shape[2] != hint_height or self.cond_hint.shape[3] != hint_width:
            # Release this object's reference to the stale hint before the
            # resized one is allocated.
            self.cond_hint = None
            self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, hint_width, hint_height, 'nearest-exact', "center").to(dtype).to(self.device)
        # Checked on every call, so a batch-size change is handled even when
        # the spatial size is unchanged.
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        context = cond['c_crossattn']
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
        return self.control_merge(None, control, control_prev, output_dtype)

    def copy(self):
        """Return a new ControlNet that shares ``control_model`` with this one.

        The constructor settings, including ``device``, are carried over, and
        the ``ControlBase`` state is copied via ``copy_to``. The copy's
        ``__init__`` builds its own ``ModelPatcher`` around the shared model.
        """
        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, device=self.device, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        self.copy_to(c)
        return c

    def get_models(self):
        """Return ``ControlBase``'s models plus this controlnet's ``ModelPatcher``."""
        out = super().get_models()
        out.append(self.control_model_wrapped)
        return out

    def pre_run(self, model, percent_to_timestep_function):
        """Run ``ControlBase.pre_run``, then store ``model.model_sampling``.

        ``get_control`` uses the stored object to map ``t`` to model timesteps
        (``timestep``) and to prepare the latent input (``calculate_input``).
        """
        super().pre_run(model, percent_to_timestep_function)
        self.model_sampling_current = model.model_sampling

    # NOTE(review): the listing this class was recovered from cites lines
    # 134–195 but shows only through 191; any members after pre_run are not
    # reproduced here.

Callers 2

copy (method) · 0.70
load_controlnet (function) · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected