Estimate max safe worker processes based on available GPU/system memory. Each worker spawns a separate process that loads its own copy of the model plus a CUDA context. This estimates how many can fit in VRAM.
(model_name="u2net", gpu_batchsize=2)
| 48 | |
| 49 | |
def max_workers(model_name="u2net", gpu_batchsize=2):
    """Estimate how many worker processes can safely run in parallel.

    Each worker spawns a separate process that loads its own copy of the model
    plus a CUDA context. On a CUDA machine the limit is how many such processes
    fit in the *total* memory of GPU 0. Free memory is not queried, so VRAM
    already used by other processes is not accounted for. Without CUDA (CPU or
    MPS), the limit is half the logical CPU count, because inference is
    compute-bound; system RAM is not considered.

    Args:
        model_name: Model identifier. ``"u2netp"`` is sized as the small
            (~5 MiB) model; every other name is sized at ~175 MiB.
        gpu_batchsize: Frames per inference batch, used to size the per-worker
            tensor budget. Negative values are treated as 0.

    Returns:
        int: Suggested number of workers, always at least 1. Returns 1 if
        the CUDA device properties cannot be read.
    """
    mib = 1024 * 1024

    if torch.cuda.is_available():
        try:
            total_mem = torch.cuda.get_device_properties(0).total_memory
        except Exception:
            # Deliberate best-effort: if the device cannot be queried, fall
            # back to the most conservative answer instead of failing.
            return 1

        # Per-worker VRAM estimate (rough figures):
        #   CUDA context per process:  ~400 MiB
        #   Model weights (float32):   ~175 MiB (u2net/human_seg), ~5 MiB (u2netp)
        #   JIT traced copy:           same as model weights
        #   Batch inference tensors:   ~30 MiB per frame in batch
        model_bytes = (5 if model_name == "u2netp" else 175) * mib

        # Clamp so a negative batch size cannot shrink per_worker to zero
        # (ZeroDivisionError) or below zero (meaningless negative estimate).
        batch = max(0, gpu_batchsize)
        per_worker = (
            400 * mib  # CUDA context overhead
            + model_bytes * 2  # model + JIT trace
            + batch * 30 * mib  # inference tensors
        )

        # Reserve 512 MiB for OS/driver/display. On very small GPUs this goes
        # negative; the max(1, ...) below absorbs that.
        usable = total_mem - 512 * mib
        # int() keeps the result integral even if gpu_batchsize was a float.
        return max(1, int(usable // per_worker))

    # CPU/MPS: limit by CPU cores (inference is compute-bound).
    cpu_count = os.cpu_count() or 2
    return max(1, cpu_count // 2)
| 87 | class Net(torch.nn.Module): |
| 88 | def __init__(self, model_name): |