Initializer for each worker process. Loads model (with tokenizers and duration estimator) onto a specific GPU via ``OmniVoice.from_pretrained()``.
(rank_queue, model_checkpoint, warmup=0)
| 205 | |
| 206 | |
def process_init(rank_queue, model_checkpoint, warmup=0):
    """Initializer for each worker process.

    Loads model (with tokenizers and duration estimator) onto a specific
    device via ``OmniVoice.from_pretrained()`` and stores it in the
    module-level ``worker_model`` global.

    Args:
        rank_queue: Queue shared with the parent process. Each worker takes
            exactly one item from it, which must unpack into
            ``(device_type, device_id)``. ``"cpu"`` and ``"mps"`` select that
            backend (``device_id`` is ignored); any other ``device_type``
            selects ``cuda:{device_id}``.
        model_checkpoint: Checkpoint path/identifier forwarded unchanged to
            ``OmniVoice.from_pretrained()``.
        warmup: Number of throw-away ``generate()`` calls to run after the
            model is loaded. ``0`` (the default) skips warmup.

    Side effects:
        Sets torch intra-op and inter-op thread counts to 2, reconfigures
        root logging with ``force=True``, and assigns the global
        ``worker_model``.
    """
    global worker_model

    # Presumably keeps several workers on one host from oversubscribing CPU
    # cores; each worker gets a small fixed thread budget.
    torch.set_num_threads(2)
    torch.set_num_interop_threads(2)

    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] "
        "[Worker %(process)d] %(message)s"
    )
    # force=True removes any handlers already attached to the root logger
    # (e.g. ones inherited from the parent when workers are forked).
    logging.basicConfig(format=formatter, level=logging.INFO, force=True)

    # Blocks until the parent has put a device assignment on the queue.
    worker_device = _resolve_worker_device(rank_queue.get())

    logging.info("Initializing worker on device: %s", worker_device)

    worker_model = OmniVoice.from_pretrained(
        model_checkpoint,
        device_map=worker_device,
        dtype=_resolve_worker_dtype(worker_device),
    )

    if warmup > 0:
        _run_warmup(worker_model, worker_device, warmup)

    logging.info("Worker on %s initialized successfully.", worker_device)


def _resolve_worker_device(rank):
    """Map a ``(device_type, device_id)`` pair to a torch device string."""
    device_type, device_id = rank
    if device_type == "cpu":
        return "cpu"
    if device_type == "mps":
        return "mps"
    # Any other device_type is treated as a CUDA rank.
    return f"cuda:{device_id}"


def _resolve_worker_dtype(worker_device):
    """Choose the model dtype for ``worker_device``.

    Accelerators (CUDA, MPS) use float16. CPU uses float32 because, depending
    on the PyTorch version, many CPU kernels either lack a half-precision
    implementation or run much slower in it.
    """
    if worker_device == "cpu":
        return torch.float32
    return torch.float16


def _run_warmup(model, worker_device, iterations):
    """Run ``iterations`` throw-away ``generate()`` calls on ``model``."""
    logging.info("Running %s warmup iterations on %s", iterations, worker_device)
    # 1 s of random noise (SAMPLING_RATE samples) as a stand-in reference clip.
    dummy_ref_audio = (torch.randn(1, SAMPLING_RATE), SAMPLING_RATE)
    for _ in range(iterations):
        model.generate(
            text=["hello"],
            language=["en"],
            ref_audio=[dummy_ref_audio],
            ref_text=["hello"],
        )
    logging.info("Warmup complete on %s", worker_device)
| 257 | |
| 258 | |
| 259 | def estimate_sample_total_duration( |
nothing calls this directly
no test coverage detected