MCPcopy
hub / github.com/nikopueringer/CorridorKey / process_frame

Method process_frame

CorridorKeyModule/backend.py:378–432  ·  view source on GitHub ↗

Delegate to MLX engine, then normalize output to Torch contract. ``screen_channel`` is accepted for API parity with the Torch engine but only ``1`` (green) is supported here — the MLX backend has no blue checkpoint yet, so the despill in ``_wrap_mlx_output`` is hard-wired to the green channel.

(
        self,
        image,
        mask_linear,
        refiner_scale=1.0,
        input_is_linear=False,
        fg_is_straight=True,
        despill_strength=1.0,
        auto_despeckle=True,
        despeckle_size=400,
        screen_channel: int = 1,
        **_kwargs,
    )

Source from the content-addressed store, hash-verified

376 logger.info("MLX adapter active: despill and despeckle are handled by the adapter layer, not native MLX")
377
def process_frame(
    self,
    image,
    mask_linear,
    refiner_scale=1.0,
    input_is_linear=False,
    fg_is_straight=True,
    despill_strength=1.0,
    auto_despeckle=True,
    despeckle_size=400,
    screen_channel: int = 1,
    **_kwargs,
):
    """Delegate to MLX engine, then normalize output to Torch contract.

    ``screen_channel`` is accepted for API parity with the Torch engine but
    only ``1`` (green) is supported here — the MLX backend has no blue
    checkpoint yet, so the despill in ``_wrap_mlx_output`` is hard-wired to
    the green channel. Calling with ``screen_channel != 1`` is a programmer
    error (the public ``create_engine`` rejects MLX + blue earlier); we
    raise instead of silently returning a green-keyed result.

    The MLX engine itself is always called with ``despill_strength=0.0`` and
    ``auto_despeckle=False``. The caller's despill/despeckle settings are
    applied afterwards by ``_wrap_mlx_output`` on the engine's raw result.

    Args:
        image: Input frame. ``uint8`` arrays are passed through unchanged;
            any other dtype is treated as normalized ``[0, 1]`` data,
            clipped to that range and rounded to the nearest ``uint8`` step.
        mask_linear: Guide mask, converted to ``uint8`` the same way as
            ``image``. If it is 3-D, only channel 0 is kept so the MLX engine
            receives a 2-D ``[H, W]`` mask.
        refiner_scale: Forwarded unchanged to the MLX engine.
        input_is_linear: Forwarded unchanged to the MLX engine.
        fg_is_straight: Forwarded unchanged to the MLX engine.
        despill_strength: Despill amount applied by ``_wrap_mlx_output``.
        auto_despeckle: Whether ``_wrap_mlx_output`` despeckles the result.
        despeckle_size: Passed to both the engine (whose despeckle is
            disabled here) and ``_wrap_mlx_output``.
        screen_channel: Screen colour channel; only ``1`` (green) is
            supported.
        **_kwargs: Accepted and ignored (presumably so callers can pass
            Torch-engine-only options without branching).

    Returns:
        The result of ``_wrap_mlx_output(raw, despill_strength,
        auto_despeckle, despeckle_size)`` for the raw MLX engine output.

    Raises:
        NotImplementedError: If ``screen_channel != 1``.
    """
    if screen_channel != 1:
        raise NotImplementedError(
            f"_MLXEngineAdapter does not support screen_channel={screen_channel}. "
            "MLX has no blue-screen checkpoint yet; use the Torch backend with "
            "--screen-color blue, or wait for the MLX blue release."
        )

    def _to_uint8(arr):
        """Return ``arr`` as ``uint8``, scaling non-``uint8`` data from [0, 1]."""
        # MLX engine expects uint8 input; uint8 arrays are passed through
        # as-is (same object, no copy).
        if arr.dtype == np.uint8:
            return arr
        # Round to nearest instead of truncating. Truncation biases every
        # value down by half an 8-bit step on average (0.5 -> 127,
        # 0.999 -> 254). It can also map a float that came from uint8
        # (k / 255 * 255 landing a hair below k) to k - 1.
        scaled = np.clip(arr, 0.0, 1.0) * 255.0
        # np.clip lets NaN through, and casting NaN to uint8 is undefined,
        # so pin NaN to 0 explicitly. This is done in place on the temporary.
        scaled = np.nan_to_num(scaled, copy=False, nan=0.0)
        np.rint(scaled, out=scaled)
        return scaled.astype(np.uint8)

    # NOTE(review): with input_is_linear=True a float linear-light frame is
    # still quantized to 8 bits here, which discards shadow precision —
    # confirm this matches what the MLX engine expects.
    image_u8 = _to_uint8(image)
    mask_u8 = _to_uint8(mask_linear)

    # Squeeze mask to 2D for MLX (it validates [H,W] or [H,W,1])
    if mask_u8.ndim == 3:
        mask_u8 = mask_u8[:, :, 0]

    raw = self._engine.process_frame(
        image_u8,
        mask_u8,
        refiner_scale=refiner_scale,
        input_is_linear=input_is_linear,
        fg_is_straight=fg_is_straight,
        # Disable the MLX stubs — the adapter applies despill/despeckle in
        # _wrap_mlx_output below.
        despill_strength=0.0,
        auto_despeckle=False,
        despeckle_size=despeckle_size,
    )

    return _wrap_mlx_output(raw, despill_strength, auto_despeckle, despeckle_size)
433
434
435DEFAULT_MLX_TILE_SIZE = 512

Calls 1

_wrap_mlx_output · Function · 0.85