| 54 | |
| 55 | |
| 56 | class BiRefNetHandler: |
def __init__(self, device="cpu", usage="General"):
    """Fetch the BiRefNet weights selected by ``usage`` and load the model.

    Args:
        device: Target handed to the model's ``.to()``; also kept as
            ``self.device``.
        usage: Variant name. It picks the working resolution stored in
            ``self.resolution`` and is the key looked up in
            ``usage_to_weights_file`` to find the weights repository.

    Note:
        ``usage_to_weights_file``, ``base_folder`` and ``half_precision`` are
        not parameters; they are resolved from outside this method. A failed
        ``usage_to_weights_file[usage]`` lookup propagates to the caller
        (presumably a ``KeyError`` -- confirm against the mapping's type).
    """
    self.device = device

    # Variants that come with a fixed, non-default input size.
    preset_sizes = (
        (("General-Lite-2K",), (2560, 1440)),
        (("General-reso_512",), (512, 512)),
        (("General-HR", "Matting-HR"), (2048, 2048)),
    )
    for variant_names, size in preset_sizes:
        if usage in variant_names:
            self.resolution = size
            break
    else:
        # No preset matched: "-dynamic" variants get no fixed size (None),
        # every other variant falls back to 1024x1024.
        self.resolution = None if "-dynamic" in usage else (1024, 1024)

    weights_name = usage_to_weights_file[usage]
    hub_repo = f"ZhengPeng7/{weights_name}"
    local_weights_dir = os.path.join(base_folder, weights_name)

    # Materialise the repository under local_weights_dir; symlinks are
    # disabled so real files are written rather than links into the cache.
    snapshot_download(
        repo_id=hub_repo,
        local_dir=local_weights_dir,
        local_dir_use_symlinks=False,
    )

    model = AutoModelForImageSegmentation.from_pretrained(local_weights_dir, trust_remote_code=False)
    self.birefnet = model

    model.to(device)
    model.eval()
    if half_precision:
        model.half()
| 89 | |
def cleanup(self):
    """Drop the loaded model and ask PyTorch to release cached CUDA memory.

    Safe to call when ``self.birefnet`` is absent (e.g. a second call): the
    attribute is only deleted if it is present. Returns ``None``.
    """
    import gc

    # Remove this handler's reference to the model, if it still has one.
    if hasattr(self, "birefnet"):
        del self.birefnet

    # Run a collection pass now that the reference is gone.
    gc.collect()

    # Nothing CUDA-related to do without CUDA.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
| 105 | |
| 106 | def process(self, input_path, alpha_output_dir=None, dilate_radius=0, on_frame_complete=None): |
| 107 | """ |
| 108 | Process a single video or directory of images. |
| 109 | """ |
| 110 | input_path = Path(input_path) |
| 111 | file_name = input_path.stem |
| 112 | is_video = input_path.suffix.lower() in [".mp4", ".mkv", ".gif", ".mov", ".avi"] |
| 113 | |