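Delegate to MLX engine, then normalize output to Torch contract. ``screen_channel`` is accepted for API parity with the Torch engine but only ``1`` (green) is supported here — the MLX backend has no blue checkpoint yet, so the despill in ``_wrap_mlx_output`` is hard-wired to the green channel. Calling with ``screen_channel != 1`` is a programmer error and raises ``NotImplementedError`` rather than silently returning a green-keyed result.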
(
self,
image,
mask_linear,
refiner_scale=1.0,
input_is_linear=False,
fg_is_straight=True,
despill_strength=1.0,
auto_despeckle=True,
despeckle_size=400,
screen_channel: int = 1,
**_kwargs,
)
| 376 | logger.info("MLX adapter active: despill and despeckle are handled by the adapter layer, not native MLX") |
| 377 | |
def process_frame(
    self,
    image,
    mask_linear,
    refiner_scale=1.0,
    input_is_linear=False,
    fg_is_straight=True,
    despill_strength=1.0,
    auto_despeckle=True,
    despeckle_size=400,
    screen_channel: int = 1,
    **_kwargs,
):
    """Delegate to MLX engine, then normalize output to Torch contract.

    ``screen_channel`` is accepted for API parity with the Torch engine but
    only ``1`` (green) is supported here — the MLX backend has no blue
    checkpoint yet, so the despill in ``_wrap_mlx_output`` is hard-wired to
    the green channel. Calling with ``screen_channel != 1`` is a programmer
    error (the public ``create_engine`` rejects MLX + blue earlier); we
    raise instead of silently returning a green-keyed result.

    Args:
        image: Input frame. A uint8 array is passed through unchanged; any
            other dtype is treated as normalized [0, 1] data, clipped, and
            rounded to uint8 because the MLX engine expects uint8 input.
        mask_linear: Alpha hint mask, converted to uint8 like ``image``. A
            3-D mask is reduced to its first channel, because MLX validates
            [H, W] or [H, W, 1].
        refiner_scale, input_is_linear, fg_is_straight: Forwarded unchanged
            to the MLX engine.
        despill_strength, auto_despeckle, despeckle_size: Applied by
            ``_wrap_mlx_output`` in this adapter. The MLX engine's own
            despill/despeckle are disabled.
        screen_channel: Must be ``1`` (green).
        **_kwargs: Accepted for signature compatibility and ignored.

    Returns:
        The result of ``_wrap_mlx_output`` for the raw MLX output, i.e. the
        output normalized to the Torch engine's contract.

    Raises:
        NotImplementedError: If ``screen_channel != 1``.
    """
    if screen_channel != 1:
        raise NotImplementedError(
            f"_MLXEngineAdapter does not support screen_channel={screen_channel}. "
            "MLX has no blue-screen checkpoint yet; use the Torch backend with "
            "--screen-color blue, or wait for the MLX blue release."
        )

    def _to_uint8(arr):
        # MLX engine expects uint8 input. Non-uint8 arrays are treated as
        # normalized [0, 1] data. Round to the nearest code value rather than
        # truncating: truncation biases every pixel down by up to one code
        # value and breaks the uint8 -> float(/255) -> uint8 round trip when
        # x * 255 lands just below an integer.
        if arr.dtype == np.uint8:
            return arr
        return np.rint(np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)

    # Squeeze mask to 2D for MLX (it validates [H,W] or [H,W,1]). Slicing
    # before conversion avoids converting channels that would be discarded.
    if mask_linear.ndim == 3:
        mask_linear = mask_linear[:, :, 0]

    image_u8 = _to_uint8(image)
    mask_u8 = _to_uint8(mask_linear)

    raw = self._engine.process_frame(
        image_u8,
        mask_u8,
        refiner_scale=refiner_scale,
        input_is_linear=input_is_linear,
        fg_is_straight=fg_is_straight,
        despill_strength=0.0,  # disable MLX stubs — adapter applies these
        auto_despeckle=False,
        despeckle_size=despeckle_size,
    )

    return _wrap_mlx_output(raw, despill_strength, auto_despeckle, despeckle_size)
| 433 | |
| 434 | |
| 435 | DEFAULT_MLX_TILE_SIZE = 512 |