MCPcopy
hub / github.com/ali-vilab/AnyDoor / predict

Method predict

predict.py:225–267  ·  view source on GitHub ↗

Run a single prediction on the model

(
        self,
        reference_image_path: Path = Input(description="Source Image"),
        reference_image_mask: Path = Input(description="Source Image"),
        bg_image_path: Path = Input(description="Target Image"),
        bg_mask_path: Path = Input(description="Target Image mask"),
        control_strength: float = Input(description="Control Strength", default=1.0, ge=0.0, le=2.0),
        steps: int = Input(description="Steps", default=50, ge=1, le=100),
        guidance_scale: float = Input(description="Guidance Scale", default=4.5, ge=0.1, le=30.0),
        enable_shape_control: bool = Input(description="Enable Shape Control", default=False),
        seed: int = Input(description="Random seed. Leave blank to randomize the seed", default=None),
    )

Source from the content-addressed store, hash-verified

223 return gen_image
224
def predict(
    self,
    reference_image_path: Path = Input(description="Source Image"),
    # NOTE(review): the description below repeats the one above; it probably
    # should read "Source Image mask". Left unchanged because it is a
    # user-facing string.
    reference_image_mask: Path = Input(description="Source Image"),
    bg_image_path: Path = Input(description="Target Image"),
    bg_mask_path: Path = Input(description="Target Image mask"),
    control_strength: float = Input(description="Control Strength", default=1.0, ge=0.0, le=2.0),
    steps: int = Input(description="Steps", default=50, ge=1, le=100),
    guidance_scale: float = Input(description="Guidance Scale", default=4.5, ge=0.1, le=30.0),
    enable_shape_control: bool = Input(description="Enable Shape Control", default=False),
    seed: int = Input(description="Random seed. Leave blank to randomize the seed", default=None),
) -> Path:
    """Run a single prediction on the model.

    Loads the reference image, its mask, the background image and the
    background mask from disk, passes them to
    ``self.inference_single_image`` and writes the result as a PNG.

    Args:
        reference_image_path: Reference image. Single-channel and
            four-channel files are converted to three channels.
        reference_image_mask: Mask for the reference image; pixels whose
            last channel is greater than 128 become 1, all others 0.
        bg_image_path: Background image.
        bg_mask_path: Mask for the background image; pixels whose first
            channel is greater than 128 become 1, all others 0.
        control_strength: Forwarded unchanged to ``inference_single_image``.
        steps: Forwarded unchanged to ``inference_single_image``.
        guidance_scale: Forwarded unchanged to ``inference_single_image``.
        enable_shape_control: Forwarded unchanged to
            ``inference_single_image``.
        seed: Random seed. When ``None`` a random 32-bit seed is drawn
            from ``os.urandom``.

    Returns:
        Path of the written output image (``/tmp/output.png``).

    Raises:
        ValueError: If any of the four input files cannot be read.
        RuntimeError: If the output image cannot be written.
    """
    if seed is None:
        seed = int.from_bytes(os.urandom(4), "big")
    print(f"Using seed: {seed}")

    def _read(path, flags=cv2.IMREAD_COLOR):
        # cv2.imread signals failure by returning None rather than raising,
        # so turn that into an explicit error naming the offending file.
        loaded = cv2.imread(str(path), flags)
        if loaded is None:
            raise ValueError(f"Could not read image: {path}")
        return loaded

    save_path = "/tmp/output.png"

    # Reference image: IMREAD_UNCHANGED keeps the file's own channel layout,
    # so grayscale files arrive as 2-D arrays (no channel axis at all).
    image = _read(reference_image_path, cv2.IMREAD_UNCHANGED)
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    ref_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    ref_mask = (_read(reference_image_mask)[:, :, -1] > 128).astype(np.uint8)

    # Background image
    back_image = cv2.cvtColor(_read(bg_image_path), cv2.COLOR_BGR2RGB)

    # Background mask
    tar_mask = (_read(bg_mask_path)[:, :, 0] > 128).astype(np.uint8)

    gen_image = self.inference_single_image(
        ref_image, ref_mask, back_image.copy(), tar_mask,
        control_strength, steps, guidance_scale, seed, enable_shape_control,
    )

    # Reverse the channel axis before handing the result to OpenCV (same
    # flip the inputs received on load); make it contiguous for imwrite.
    bgr_output = np.ascontiguousarray(gen_image[:, :, ::-1])
    if not cv2.imwrite(save_path, bgr_output):
        raise RuntimeError(f"Could not write output image: {save_path}")

    return Path(save_path)

Callers

nothing calls this directly

Calls 2

print · Function · 0.85

Tested by

no test coverage detected