Run a single prediction on the model
(
self,
reference_image_path: Path = Input(description="Source Image"),
reference_image_mask: Path = Input(description="Source Image"),
bg_image_path: Path = Input(description="Target Image"),
bg_mask_path: Path = Input(description="Target Image mask"),
control_strength: float = Input(description="Control Strength", default=1.0, ge=0.0, le=2.0),
steps: int = Input(description="Steps", default=50, ge=1, le=100),
guidance_scale: float = Input(description="Guidance Scale", default=4.5, ge=0.1, le=30.0),
enable_shape_control: bool = Input(description="Enable Shape Control", default=False),
seed: int = Input(description="Random seed. Leave blank to randomize the seed", default=None),
)
| 223 | return gen_image |
| 224 | |
def predict(
    self,
    reference_image_path: Path = Input(description="Source Image"),
    # NOTE(review): this description repeats the one above; it presumably
    # should read "Source Image mask". Left unchanged because the text is
    # part of the published input schema -- confirm before editing.
    reference_image_mask: Path = Input(description="Source Image"),
    bg_image_path: Path = Input(description="Target Image"),
    bg_mask_path: Path = Input(description="Target Image mask"),
    control_strength: float = Input(description="Control Strength", default=1.0, ge=0.0, le=2.0),
    steps: int = Input(description="Steps", default=50, ge=1, le=100),
    guidance_scale: float = Input(description="Guidance Scale", default=4.5, ge=0.1, le=30.0),
    enable_shape_control: bool = Input(description="Enable Shape Control", default=False),
    seed: int = Input(description="Random seed. Leave blank to randomize the seed", default=None),
) -> Path:
    """Run a single prediction on the model.

    Loads the reference image, the background image and their masks from
    disk, passes them to ``self.inference_single_image`` and writes the
    generated image to ``/tmp/output.png``.

    Args:
        reference_image_path: Image file read as the reference (source).
        reference_image_mask: Mask file for the reference image; pixels
            whose last channel exceeds 128 become 1, all others 0.
        bg_image_path: Image file read as the background (target).
        bg_mask_path: Mask file for the background image; pixels whose
            first channel exceeds 128 become 1, all others 0.
        control_strength: Forwarded unchanged to the inference call.
        steps: Forwarded unchanged to the inference call.
        guidance_scale: Forwarded unchanged to the inference call.
        enable_shape_control: Forwarded unchanged to the inference call.
        seed: Random seed. When ``None`` a 32-bit seed is drawn from
            ``os.urandom``.

    Returns:
        Path of the PNG file that was written.

    Raises:
        ValueError: If any of the four input files cannot be decoded.
        RuntimeError: If the output image cannot be written.
    """

    def _read(path, flags, label):
        # cv2.imread reports failure by returning None rather than raising,
        # which otherwise surfaces later as an opaque AttributeError/TypeError.
        loaded = cv2.imread(str(path), flags)
        if loaded is None:
            raise ValueError(f"Could not read {label}: {path}")
        return loaded

    if seed is None:
        seed = int.from_bytes(os.urandom(4), "big")
    print(f"Using seed: {seed}")

    save_path = "/tmp/output.png"

    # Reference image: read with its native channel layout, then normalise
    # to 3-channel BGR before converting to RGB.
    image = _read(reference_image_path, cv2.IMREAD_UNCHANGED, "reference image")
    # Single-channel files come back as 2-D arrays, so the channel axis must
    # not be indexed before checking the rank.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    ref_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Reference mask: binarise the last channel.
    ref_mask_raw = _read(reference_image_mask, cv2.IMREAD_COLOR, "reference mask")
    ref_mask = (ref_mask_raw[:, :, -1] > 128).astype(np.uint8)

    # Background image.
    back_image = _read(bg_image_path, cv2.IMREAD_COLOR, "background image")
    back_image = cv2.cvtColor(back_image, cv2.COLOR_BGR2RGB)

    # Background mask: binarise the first channel.
    tar_mask_raw = _read(bg_mask_path, cv2.IMREAD_COLOR, "background mask")
    tar_mask = (tar_mask_raw[:, :, 0] > 128).astype(np.uint8)

    # A copy of the background is passed so the local array stays untouched.
    gen_image = self.inference_single_image(
        ref_image, ref_mask, back_image.copy(), tar_mask,
        control_strength, steps, guidance_scale, seed, enable_shape_control)

    # The previous code resized the reference image here (taking the width
    # from shape[0] by mistake) and wrapped gen_image in a single-element
    # hconcat; neither affected the written file, so both were removed.
    # The channel axis is reversed before writing, as in the original;
    # presumably gen_image is RGB and cv2.imwrite expects BGR -- TODO confirm.
    if not cv2.imwrite(save_path, gen_image[:, :, ::-1]):
        raise RuntimeError(f"Could not write output image: {save_path}")

    return Path(save_path)
nothing calls this directly
no test coverage detected