Process a single video or directory of images.
(self, input_path, alpha_output_dir=None, dilate_radius=0, on_frame_complete=None)
| 104 | torch.cuda.ipc_collect() |
| 105 | |
| 106 | def process(self, input_path, alpha_output_dir=None, dilate_radius=0, on_frame_complete=None): |
| 107 | """ |
| 108 | Process a single video or directory of images. |
| 109 | """ |
| 110 | input_path = Path(input_path) |
| 111 | file_name = input_path.stem |
| 112 | is_video = input_path.suffix.lower() in [".mp4", ".mkv", ".gif", ".mov", ".avi"] |
| 113 | |
| 114 | def get_frames(): |
| 115 | """Yields tuples of (image_numpy_array, output_file_name)""" |
| 116 | if is_video: |
| 117 | cap = cv2.VideoCapture(str(input_path)) |
| 118 | count = 0 |
| 119 | while True: |
| 120 | success, img = cap.read() |
| 121 | if not success: |
| 122 | break |
| 123 | yield img, f"{file_name}_alpha_{count:05d}.png" |
| 124 | count += 1 |
| 125 | cap.release() |
| 126 | else: |
| 127 | image_files = sorted( |
| 128 | [ |
| 129 | f |
| 130 | for f in input_path.iterdir() |
| 131 | if f.is_file() and f.suffix.lower() in [".jpg", ".png", ".jpeg", ".exr"] |
| 132 | ] |
| 133 | ) |
| 134 | if not image_files: |
| 135 | logging.warning(f"No images found in {input_path}") |
| 136 | return |
| 137 | |
| 138 | # Setup EXR support once if needed |
| 139 | if "OPENCV_IO_ENABLE_OPENEXR" not in os.environ: |
| 140 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" |
| 141 | |
| 142 | for img_path in image_files: |
| 143 | img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED) |
| 144 | if img is None: |
| 145 | continue |
| 146 | # Keep original filename for image sequences |
| 147 | yield img, f"alphaSeq_{img_path.stem}.png" |
| 148 | |
| 149 | count = 0 |
| 150 | for image, out_name in get_frames(): |
| 151 | # Ensure correct conversion to RGB regardless of input format (EXR/PNG/JPG) |
| 152 | if len(image.shape) == 2: |
| 153 | image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) |
| 154 | elif image.shape[2] == 4: |
| 155 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) |
| 156 | else: |
| 157 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
| 158 | |
| 159 | # EXR images load as float32. PIL expects uint8. Normalize if necessary. |
| 160 | if image_rgb.dtype != np.uint8: |
| 161 | image_rgb = cv2.normalize(image_rgb, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) |
| 162 | |
| 163 | pil_image = Image.fromarray(image_rgb) |
no test coverage detected